mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-11 18:22:04 +08:00
Compare commits
8 Commits
fix-review
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc8d903217 | ||
|
|
5a9a941a89 | ||
|
|
87beae7771 | ||
|
|
251676dfda | ||
|
|
896fec351b | ||
|
|
4548e68e80 | ||
|
|
b80d3f6872 | ||
|
|
acc07f5cda |
1
.github/workflows/pr_dependency_test.yml
vendored
1
.github/workflows/pr_dependency_test.yml
vendored
@@ -6,6 +6,7 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- "src/diffusers/**.py"
|
||||
- "tests/**.py"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
@@ -6,6 +6,7 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- "src/diffusers/**.py"
|
||||
- "tests/**.py"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
@@ -26,7 +27,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -e .
|
||||
pip install torch torchvision torchaudio pytest
|
||||
pip install torch pytest
|
||||
- name: Check for soft dependencies
|
||||
run: |
|
||||
pytest tests/others/test_dependencies.py
|
||||
|
||||
@@ -350,6 +350,8 @@
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
title: EasyAnimateTransformer3DModel
|
||||
- local: api/models/ernie_image_transformer2d
|
||||
title: ErnieImageTransformer2DModel
|
||||
- local: api/models/flux2_transformer
|
||||
title: Flux2Transformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
@@ -534,6 +536,8 @@
|
||||
title: DiT
|
||||
- local: api/pipelines/easyanimate
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/ernie_image
|
||||
title: ERNIE-Image
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/flux2
|
||||
|
||||
21
docs/source/en/api/models/ernie_image_transformer2d.md
Normal file
21
docs/source/en/api/models/ernie_image_transformer2d.md
Normal file
@@ -0,0 +1,21 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ErnieImageTransformer2DModel
|
||||
|
||||
A Transformer model for image-like data from [ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image).
|
||||
|
||||
A Transformer model for image-like data from [ERNIE-Image-Turbo](https://huggingface.co/baidu/ERNIE-Image-Turbo).
|
||||
|
||||
## ErnieImageTransformer2DModel
|
||||
|
||||
[[autodoc]] ErnieImageTransformer2DModel
|
||||
86
docs/source/en/api/pipelines/ernie_image.md
Normal file
86
docs/source/en/api/pipelines/ernie_image.md
Normal file
@@ -0,0 +1,86 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ERNIE-Image
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
[ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image) is a powerful and highly efficient image generation model with 8B parameters. Currently, there are only two models to be released:
|
||||
|
||||
|Model|Hugging Face|
|
||||
|---|---|
|
||||
|ERNIE-Image|https://huggingface.co/baidu/ERNIE-Image|
|
||||
|ERNIE-Image-Turbo|https://huggingface.co/baidu/ERNIE-Image-Turbo|
|
||||
|
||||
## ERNIE-Image
|
||||
|
||||
ERNIE-Image is designed with a relatively compact architecture and solid instruction-following capability, emphasizing parameter efficiency. Based on an 8B DiT backbone, it provides performance that is comparable in some scenarios to larger (20B+) models, while maintaining reasonable parameter efficiency. It offers a relatively stable level of performance in instruction understanding and execution, text generation (e.g., English / Chinese / Japanese), and overall stability.
|
||||
|
||||
## ERNIE-Image-Turbo
|
||||
|
||||
ERNIE-Image-Turbo is a distilled variant of ERNIE-Image, requiring only 8 NFEs (Number of Function Evaluations) and offering a more efficient alternative with relatively comparable performance to the full model in certain cases.
|
||||
|
||||
## ErnieImagePipeline
|
||||
|
||||
Use [`ErnieImagePipeline`] to generate images from text prompts. The pipeline supports a Prompt Enhancer (PE) by default, which rewrites the user's raw prompt to improve output quality, though it may reduce instruction-following accuracy.
|
||||
|
||||
We provide a pretrained 3B-parameter PE model; however, using larger language models (e.g., Gemini or ChatGPT) for prompt enhancement may yield better results. The system prompt template is available at: https://huggingface.co/baidu/ERNIE-Image/blob/main/pe/chat_template.jinja.
|
||||
|
||||
If you prefer not to use PE, set `use_pe=False`.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ErnieImagePipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
# If you are running low on GPU VRAM, you can enable offloading
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "一只黑白相间的中华田园犬"
|
||||
images = pipe(
|
||||
prompt=prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=4.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
use_pe=True,
|
||||
).images
|
||||
images[0].save("ernie-image-output.png")
|
||||
```
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ErnieImagePipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
# If you are running low on GPU VRAM, you can enable offloading
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "一只黑白相间的中华田园犬"
|
||||
images = pipe(
|
||||
prompt=prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=1.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
use_pe=True,
|
||||
).images
|
||||
images[0].save("ernie-image-turbo-output.png")
|
||||
```
|
||||
@@ -101,9 +101,9 @@ export_to_video(video, "output.mp4", fps=16)
|
||||
|
||||
## LoRA
|
||||
|
||||
Adapters insert a small number of trainable parameters to the original base model. Only the inserted parameters are fine-tuned while the rest of the model weights remain frozen. This makes it fast and cheap to fine-tune a model on a new style. Among adapters, [LoRA's](./tutorials/using_peft_for_inference) are the most popular.
|
||||
Adapters insert a small number of trainable parameters to the original base model. Only the inserted parameters are fine-tuned while the rest of the model weights remain frozen. This makes it fast and cheap to fine-tune a model on a new style. Among adapters, [LoRAs](./tutorials/using_peft_for_inference) are the most popular.
|
||||
|
||||
Add a LoRA to a pipeline with the [`~loaders.QwenImageLoraLoaderMixin.load_lora_weights`] method. Some LoRA's require a special word to trigger it, such as `Realism`, in the example below. Check a LoRA's model card to see if it requires a trigger word.
|
||||
Add a LoRA to a pipeline with the [`~loaders.QwenImageLoraLoaderMixin.load_lora_weights`] method. Some LoRAs require a special word to trigger them, such as `Realism`, in the example below. Check a LoRA's model card to see if it requires a trigger word.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
@@ -906,6 +906,68 @@ class PromptDataset(Dataset):
|
||||
return example
|
||||
|
||||
|
||||
# These helpers only matter for prior preservation, where instance and class prompt
|
||||
# embedding batches are concatenated and may not share the same mask/sequence length.
|
||||
def _materialize_prompt_embedding_mask(
|
||||
prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
"""Return a dense mask tensor for a prompt embedding batch."""
|
||||
batch_size, seq_len = prompt_embeds.shape[:2]
|
||||
|
||||
if prompt_embeds_mask is None:
|
||||
return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)
|
||||
|
||||
if prompt_embeds_mask.shape != (batch_size, seq_len):
|
||||
raise ValueError(
|
||||
f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
|
||||
f"({batch_size}, {seq_len})."
|
||||
)
|
||||
|
||||
return prompt_embeds_mask.to(device=prompt_embeds.device)
|
||||
|
||||
|
||||
def _pad_prompt_embedding_pair(
    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad a prompt embedding batch and its mask up to `target_seq_len`.

    Args:
        prompt_embeds: Embeddings indexed as `(batch, seq_len, dim)`.
        prompt_embeds_mask: Matching mask, or `None` (materialized as all ones first).
        target_seq_len: Sequence length to pad to.

    Returns:
        `(prompt_embeds, prompt_embeds_mask)`. Both are returned unpadded when the batch
        is already at least `target_seq_len` long; otherwise zeros are appended along the
        sequence dimension of both tensors, so padded positions are masked out.
    """
    dense_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)
    missing = target_seq_len - prompt_embeds.shape[1]

    if missing > 0:
        # Zero embeddings and zero mask entries for the positions that do not exist.
        embed_padding = prompt_embeds.new_zeros(prompt_embeds.shape[0], missing, prompt_embeds.shape[2])
        mask_padding = dense_mask.new_zeros(dense_mask.shape[0], missing)
        prompt_embeds = torch.cat([prompt_embeds, embed_padding], dim=1)
        dense_mask = torch.cat([dense_mask, mask_padding], dim=1)

    return prompt_embeds, dense_mask
|
||||
|
||||
|
||||
def concat_prompt_embedding_batches(
    *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Stack several `(prompt_embeds, prompt_embeds_mask)` batches along the batch axis.

    Each pair is padded to the longest sequence length among the inputs (missing masks
    are treated as all-valid) before concatenation.

    Args:
        *prompt_embedding_pairs: One or more `(embeddings, mask_or_None)` tuples.

    Returns:
        `(merged_embeds, merged_mask)`. The mask is `None` when every position of the
        merged mask is valid.

    Raises:
        ValueError: If no pair is given.
    """
    if len(prompt_embedding_pairs) == 0:
        raise ValueError("At least one prompt embedding pair must be provided.")

    longest_seq_len = max(embeds.shape[1] for embeds, _ in prompt_embedding_pairs)

    embeds_parts = []
    mask_parts = []
    for embeds, mask in prompt_embedding_pairs:
        padded_embeds, padded_mask = _pad_prompt_embedding_pair(embeds, mask, longest_seq_len)
        embeds_parts.append(padded_embeds)
        mask_parts.append(padded_mask)

    merged_prompt_embeds = torch.cat(embeds_parts, dim=0)
    merged_mask = torch.cat(mask_parts, dim=0)

    # A mask with no masked-out position carries no information; report it as None.
    return merged_prompt_embeds, (None if merged_mask.all() else merged_mask)
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.report_to == "wandb" and args.hub_token is not None:
|
||||
raise ValueError(
|
||||
@@ -1320,8 +1382,10 @@ def main(args):
|
||||
prompt_embeds = instance_prompt_embeds
|
||||
prompt_embeds_mask = instance_prompt_embeds_mask
|
||||
if args.with_prior_preservation:
|
||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
|
||||
prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
|
||||
prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
|
||||
(instance_prompt_embeds, instance_prompt_embeds_mask),
|
||||
(class_prompt_embeds, class_prompt_embeds_mask),
|
||||
)
|
||||
|
||||
# if cache_latents is set to True, we encode images to latents and store them.
|
||||
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
|
||||
@@ -1465,7 +1529,10 @@ def main(args):
|
||||
prompt_embeds = prompt_embeds_cache[step]
|
||||
prompt_embeds_mask = prompt_embeds_mask_cache[step]
|
||||
else:
|
||||
num_repeat_elements = len(prompts)
|
||||
# With prior preservation, prompt_embeds already contains [instance, class] embeddings
|
||||
# from the cat above, but collate_fn also doubles the prompts list. Use half the
|
||||
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
|
||||
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
|
||||
|
||||
@@ -235,6 +235,7 @@ else:
|
||||
"CosmosTransformer3DModel",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
"ErnieImageTransformer2DModel",
|
||||
"Flux2Transformer2DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
@@ -455,6 +456,8 @@ else:
|
||||
"HeliosPyramidDistilledAutoBlocks",
|
||||
"HeliosPyramidDistilledModularPipeline",
|
||||
"HeliosPyramidModularPipeline",
|
||||
"LTXAutoBlocks",
|
||||
"LTXModularPipeline",
|
||||
"QwenImageAutoBlocks",
|
||||
"QwenImageEditAutoBlocks",
|
||||
"QwenImageEditModularPipeline",
|
||||
@@ -525,6 +528,7 @@ else:
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"ErnieImagePipeline",
|
||||
"Flux2KleinKVPipeline",
|
||||
"Flux2KleinPipeline",
|
||||
"Flux2Pipeline",
|
||||
@@ -1035,6 +1039,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
ErnieImageTransformer2DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
@@ -1234,6 +1239,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HeliosPyramidDistilledAutoBlocks,
|
||||
HeliosPyramidDistilledModularPipeline,
|
||||
HeliosPyramidModularPipeline,
|
||||
LTXAutoBlocks,
|
||||
LTXModularPipeline,
|
||||
QwenImageAutoBlocks,
|
||||
QwenImageEditAutoBlocks,
|
||||
QwenImageEditModularPipeline,
|
||||
@@ -1300,6 +1307,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
ErnieImagePipeline,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2KleinPipeline,
|
||||
Flux2Pipeline,
|
||||
|
||||
@@ -101,6 +101,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ernie_image"] = ["ErnieImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
|
||||
@@ -219,6 +220,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
ErnieImageTransformer2DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
|
||||
@@ -25,6 +25,7 @@ if is_torch_available():
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_cosmos import CosmosTransformer3DModel
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_ernie_image import ErnieImageTransformer2DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_flux2 import Flux2Transformer2DModel
|
||||
from .transformer_glm_image import GlmImageTransformer2DModel
|
||||
|
||||
430
src/diffusers/models/transformers/transformer_ernie_image.py
Normal file
430
src/diffusers/models/transformers/transformer_ernie_image.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Ernie-Image Transformer2DModel for HuggingFace Diffusers.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention import AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
class ErnieImageTransformer2DModelOutput(BaseOutput):
    """Output container for the ERNIE-Image transformer.

    Attributes:
        sample: Tensor produced by the model.
            NOTE(review): presumably the denoiser prediction in latent layout; confirm
            against `ErnieImageTransformer2DModel.forward`.
    """

    sample: torch.Tensor
|
||||
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Compute rotary-embedding angles for a tensor of positions.

    Args:
        pos: Position values of shape `[..., n]`.
        dim: Rotary dimension for this axis; must be even. `dim // 2` frequencies are used.
        theta: Base of the frequency progression.

    Returns:
        A float32 tensor of shape `[..., n, dim // 2]` whose entry `[..., k, i]` is
        `pos[..., k] / theta ** (2 * i / dim)`.

    Raises:
        ValueError: If `dim` is odd.
    """
    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    if dim % 2 != 0:
        raise ValueError(f"`dim` must be even, got {dim}.")

    # Frequencies are built in float64; the result is cast to float32 once at the end.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**exponents)
    angles = torch.einsum("...n,d->...nd", pos, omega)
    return angles.float()
|
||||
|
||||
|
||||
class ErnieImageEmbedND3(nn.Module):
    """Rotary position embedding over three position axes.

    Args:
        dim: Stored on the module; not read by `forward`.
        theta: Frequency base forwarded to `rope`.
        axes_dim: Rotary dimension assigned to each of the three axes.
    """

    def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = list(axes_dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Build the rotary angle table for `ids`, whose last dimension indexes the three axes.

        Returns:
            Angles with a singleton axis inserted at dim 2 and every angle duplicated
            pairwise along the last dimension (`[a0, a0, a1, a1, ...]`).
        """
        per_axis_angles = []
        for axis in range(3):
            per_axis_angles.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))

        # [B, S, 1, head_dim//2]
        angles = torch.cat(per_axis_angles, dim=-1).unsqueeze(2)

        # Stack each angle with itself and merge the pair axis -> [B, S, 1, head_dim]
        paired = torch.stack((angles, angles), dim=-1)
        return paired.reshape(*angles.shape[:-1], -1)
|
||||
|
||||
|
||||
class ErnieImagePatchEmbedDynamic(nn.Module):
    """Patchify an image-like tensor with a strided convolution and flatten it to tokens.

    Args:
        in_channels: Channels of the input tensor.
        embed_dim: Channels of each output token.
        patch_size: Kernel size and stride of the patch convolution.
    """

    def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map `[B, C, H, W]` to a contiguous `[B, H' * W', embed_dim]` token sequence."""
        patches = self.proj(x)
        batch, channels, rows, cols = patches.shape
        tokens = patches.reshape(batch, channels, rows * cols)
        return tokens.permute(0, 2, 1).contiguous()
|
||||
|
||||
|
||||
class ErnieImageSingleStreamAttnProcessor:
    """Attention processor for ERNIE-Image single-stream self-attention.

    Projects the hidden states to query/key/value, applies the optional q/k norms and
    rotary embedding, and runs the attention through `dispatch_attention_fn`.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "ErnieImageSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    @staticmethod
    def _apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Rotate the leading `freqs_cis.shape[-1]` channels of `x_in`; pass the rest through.

        Uses the non-interleaved rotate-half form: the rotated companion of `[x1, x2]`
        is `[-x2, x1]`.
        """
        rot_dim = freqs_cis.shape[-1]
        rotary_part = x_in[..., :rot_dim]
        passthrough = x_in[..., rot_dim:]
        cos_ = torch.cos(freqs_cis).to(rotary_part.dtype)
        sin_ = torch.sin(freqs_cis).to(rotary_part.dtype)
        first_half, second_half = rotary_part.chunk(2, dim=-1)
        rotated = torch.cat((-second_half, first_half), dim=-1)
        return torch.cat((rotary_part * cos_ + rotated * sin_, passthrough), dim=-1)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        freqs_cis: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run self-attention over `hidden_states` using the projections held by `attn`.

        Args:
            attn: Module providing `to_q`, `to_k`, `to_v`, `to_out`, `heads`, `norm_q`, `norm_k`.
            hidden_states: Batch-first input sequence.
            attention_mask: Optional mask; a 2-D `[batch, seq_len]` mask is expanded to
                `[batch, 1, 1, seq_len]` so it broadcasts over heads and query positions.
            freqs_cis: Optional rotary angles applied to query and key.

        Returns:
            Output of `attn.to_out[0]` applied to the attention result.
        """
        head_split = (attn.heads, -1)
        query = attn.to_q(hidden_states).unflatten(-1, head_split)
        key = attn.to_k(hidden_states).unflatten(-1, head_split)
        value = attn.to_v(hidden_states).unflatten(-1, head_split)

        # Optional q/k normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # RoPE on [B, S, heads, head_dim] with freqs_cis of shape [B, S, 1, head_dim]
        if freqs_cis is not None:
            query = self._apply_rotary_emb(query, freqs_cis)
            key = self._apply_rotary_emb(key, freqs_cis)

        # The query dtype is the reference dtype for the key and for the output.
        dtype = query.dtype
        key = key.to(dtype)

        # [batch, seq_len] -> [batch, 1, 1, seq_len]
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        attn_output = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

        # Merge the head axis back into the channel axis.
        attn_output = attn_output.flatten(2, 3).to(dtype)
        return attn.to_out[0](attn_output)
|
||||
|
||||
|
||||
class ErnieImageAttention(torch.nn.Module, AttentionModuleMixin):
    """Self-attention module for ERNIE-Image transformer blocks.

    Owns the q/k/v projections, the optional per-head q/k normalization and the output
    projection. The attention computation itself is delegated to the processor.

    Args:
        query_dim: Input feature size.
        heads: Number of attention heads (recomputed as `out_dim // dim_head` when
            `out_dim` is given).
        dim_head: Channels per head.
        dropout: Stored on the module; not applied by this class.
        bias: Whether the q/k/v projections use a bias.
        qk_norm: `"rms_norm"`, `"layer_norm"`, or `None` to skip q/k normalization.
        added_proj_bias: Stored on the module; not used by this class.
        out_bias: Whether the output projection uses a bias.
        eps: Epsilon of the q/k norm layers.
        out_dim: Optional override for the inner and output dimension.
        elementwise_affine: Whether the q/k norm layers have learnable parameters.
        processor: Attention processor; defaults to `_default_processor_cls()`.

    Raises:
        ValueError: If `qk_norm` is not one of the supported values.
    """

    _default_processor_cls = ErnieImageSingleStreamAttnProcessor

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: str | None = "rms_norm",
        added_proj_bias: bool | None = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int | None = None,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.added_proj_bias = added_proj_bias

        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        # QK Norm. `None` disables it; the processor checks `norm_q`/`norm_k` for None
        # before applying them, so the attributes must exist in every case.
        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        elif qk_norm == "rms_norm":
            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'.")

        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate to the processor, dropping keyword arguments it does not accept.

        `encoder_hidden_states` is accepted for signature compatibility but is not
        forwarded to the processor. `attention_mask` and `image_rotary_emb` are passed
        positionally after `hidden_states`.
        """
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k in kwargs if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class ErnieImageFeedForward(nn.Module):
    """Gated feed-forward layer: `linear_fc2(up_proj(x) * gelu(gate_proj(x)))`.

    Args:
        hidden_size: Input and output feature size.
        ffn_hidden_size: Width of the gated intermediate projection.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int):
        super().__init__()
        # Gate and up projections are kept as separate layers (matches converted weights).
        self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to `x` and map it back to `hidden_size`."""
        value = self.up_proj(x)
        gate = F.gelu(self.gate_proj(x))
        return self.linear_fc2(value * gate)
|
||||
|
||||
|
||||
class ErnieImageSharedAdaLNBlock(nn.Module):
    """Transformer block with externally supplied AdaLN modulation.

    The block has a self-attention branch and a feed-forward branch. Each branch is
    RMS-normalized, modulated with a scale/shift pair, and added back through a gate.
    The six modulation tensors are passed in via `temb`, not computed here.

    Args:
        hidden_size: Feature size of the sequence.
        num_heads: Number of attention heads; the head size is `hidden_size // num_heads`.
        ffn_hidden_size: Intermediate width of the feed-forward layer.
        eps: Epsilon for the norm layers.
        qk_layernorm: When true, the attention uses `"rms_norm"` q/k normalization;
            otherwise `None` is passed as `qk_norm`.
    """

    def __init__(
        self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True
    ):
        super().__init__()
        qk_norm = "rms_norm" if qk_layernorm else None

        self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps)
        self.self_attention = ErnieImageAttention(
            query_dim=hidden_size,
            heads=num_heads,
            dim_head=hidden_size // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=False,
            out_bias=False,
            processor=ErnieImageSingleStreamAttnProcessor(),
        )
        self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps)
        self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)

    def forward(
        self,
        x,
        rotary_pos_emb,
        temb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
    ):
        """Apply the attention and feed-forward branches to `x`.

        Args:
            x: Sequence-first hidden states `[S, B, H]`.
            rotary_pos_emb: Rotary angles forwarded to the attention as `image_rotary_emb`.
            temb: `(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)`.
            attention_mask: Optional mask forwarded to the attention.

        Returns:
            Updated hidden states in the same layout as `x`.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb

        # --- self-attention branch (modulation math runs in float32) ---
        attn_in = self.adaLN_sa_ln(x)
        attn_in = (attn_in.float() * (1 + scale_msa.float()) + shift_msa.float()).to(attn_in.dtype)
        # [S, B, H] -> [B, S, H]: the attention module is batch-first
        attn_out = self.self_attention(
            attn_in.permute(1, 0, 2), attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb
        )
        # [B, S, H] -> [S, B, H]
        attn_out = attn_out.permute(1, 0, 2)
        x = x + (gate_msa.float() * attn_out.float()).to(attn_in.dtype)

        # --- feed-forward branch ---
        mlp_in = self.adaLN_mlp_ln(x)
        mlp_in = (mlp_in.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(mlp_in.dtype)
        mlp_out = self.mlp(mlp_in)
        return x + (gate_mlp.float() * mlp_out.float()).to(mlp_in.dtype)
|
||||
|
||||
|
||||
class ErnieImageAdaLNContinuous(nn.Module):
    """Parameter-free LayerNorm followed by a conditioning-driven scale and shift.

    Args:
        hidden_size: Feature size of both the input and the conditioning vector.
        eps: Epsilon of the LayerNorm.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
        self.linear = nn.Linear(hidden_size, hidden_size * 2)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        """Return `norm(x) * (1 + scale) + shift`.

        `scale` and `shift` are the two halves of `linear(conditioning)`. Each gets a new
        leading axis so it broadcasts over the first dimension of `x`.
        """
        modulation = self.linear(conditioning)
        scale, shift = torch.chunk(modulation, 2, dim=-1)
        normed = self.norm(x)
        return normed * (1 + scale[None]) + shift[None]
|
||||
|
||||
|
||||
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 3072,
|
||||
num_attention_heads: int = 24,
|
||||
num_layers: int = 24,
|
||||
ffn_hidden_size: int = 8192,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 128,
|
||||
patch_size: int = 1,
|
||||
text_in_dim: int = 2560,
|
||||
rope_theta: int = 256,
|
||||
rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
|
||||
eps: float = 1e-6,
|
||||
qk_layernorm: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.num_layers = num_layers
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.text_in_dim = text_in_dim
|
||||
|
||||
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
|
||||
self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
|
||||
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
|
||||
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
|
||||
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
|
||||
nn.init.zeros_(self.adaLN_modulation[-1].weight)
|
||||
nn.init.zeros_(self.adaLN_modulation[-1].bias)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ErnieImageSharedAdaLNBlock(
|
||||
hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
|
||||
self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
|
||||
nn.init.zeros_(self.final_linear.weight)
|
||||
nn.init.zeros_(self.final_linear.bias)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    text_bth: torch.Tensor,
    text_lens: torch.Tensor,
    return_dict: bool = True,
):
    """
    Run the transformer on a batch of image latents conditioned on padded text features.

    Image tokens and text tokens are concatenated (image first) into one sequence-first tensor, processed by every
    block in `self.layers` with a single set of AdaLN modulation tensors shared by all blocks, and only the image
    tokens are projected back to the input layout.

    Args:
        hidden_states (`torch.Tensor`):
            4D tensor unpacked as `(B, C, H, W)`.
        timestep (`torch.Tensor`):
            Timestep values fed to `self.time_proj`.
        text_bth (`torch.Tensor`):
            Text features; dimension 1 is read as the padded text length `Tmax`. Presumably shaped
            `(B, Tmax, text_in_dim)` - TODO confirm against the caller.
        text_lens (`torch.Tensor`):
            One value per batch element (it is viewed as `(B, 1, 1)` and `(B, 1)`): the number of valid text tokens.
        return_dict (`bool`, defaults to `True`):
            Whether to wrap the result in `ErnieImageTransformer2DModelOutput` instead of returning a 1-tuple.

    Returns:
        `ErnieImageTransformer2DModelOutput` with `sample` of shape `(B, self.out_channels, H, W)`, or `(sample,)`
        when `return_dict` is `False`.
    """
    device, dtype = hidden_states.device, hidden_states.dtype
    B, C, H, W = hidden_states.shape
    p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
    N_img = Hp * Wp

    # Patch-embed, then swap the first two dims so the sequence dimension comes first.
    img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
    # NOTE(review): text padding looks like it is now expected from the caller (a former in-model `_pad_text`
    # call was disabled here) - confirm.
    if self.text_proj is not None and text_bth.numel() > 0:
        text_bth = self.text_proj(text_bth)
    Tmax = text_bth.shape[1]
    text_sbh = text_bth.transpose(0, 1).contiguous()

    # Joint sequence: image tokens first, text tokens after.
    x = torch.cat([img_sbh, text_sbh], dim=0)
    S = x.shape[0]

    # Position IDs
    # Text tokens: (token index, 0, 0).
    text_ids = (
        torch.cat(
            [
                torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
                torch.zeros((B, Tmax, 2), device=device),
            ],
            dim=-1,
        )
        if Tmax > 0
        else torch.zeros((B, 0, 3), device=device)
    )
    grid_yx = torch.stack(
        torch.meshgrid(
            torch.arange(Hp, device=device, dtype=torch.float32),
            torch.arange(Wp, device=device, dtype=torch.float32),
            indexing="ij",
        ),
        dim=-1,
    ).reshape(-1, 2)
    # Image tokens: (text length of that sample, row, column).
    image_ids = torch.cat(
        [text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)],
        dim=-1,
    )
    rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))

    # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention
    valid_text = (
        torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1)
        if Tmax > 0
        else torch.zeros((B, 0), device=device, dtype=torch.bool)
    )
    # All image tokens are valid; the result has shape (B, 1, 1, N_img + Tmax).
    attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[
        :, None, None, :
    ]

    # AdaLN
    sample = self.time_proj(timestep.to(dtype))
    # Match the dtype of the embedding MLP weights before calling it.
    sample = sample.to(self.time_embedding.linear_1.weight.dtype)
    c = self.time_embedding(sample)
    # Six modulation tensors, each expanded along a new leading dim of size S; computed once and reused by every layer.
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
        t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
    ]
    for layer in self.layers:
        temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            x = self._gradient_checkpointing_func(
                layer,
                x,
                rotary_pos_emb,
                temb,
                attention_mask,
            )
        else:
            x = layer(x, rotary_pos_emb, temb, attention_mask)
    x = self.final_norm(x, c).type_as(x)
    # Keep only the first N_img sequence positions (the image tokens) and go back to batch-first.
    patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
    # Un-patchify: (B, Hp, Wp, p, p, C_out) -> (B, C_out, Hp, p, Wp, p) -> (B, C_out, H, W).
    output = (
        patches.view(B, Hp, Wp, p, p, self.out_channels)
        .permute(0, 5, 1, 3, 2, 4)
        .contiguous()
        .view(B, self.out_channels, H, W)
    )

    return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,)
|
||||
@@ -233,6 +233,11 @@ class QwenEmbedRope(nn.Module):
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
@lru_cache_unless_export(maxsize=None)
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Transfer both frequency tables to `device` and return them as `(pos_freqs, neg_freqs)`."""
    pos_on_device = self.pos_freqs.to(device)
    neg_on_device = self.neg_freqs.to(device)
    return pos_on_device, neg_on_device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
|
||||
@@ -300,8 +305,9 @@ class QwenEmbedRope(nn.Module):
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
# Create device-specific copy for text freqs without modifying self.pos_freqs
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
|
||||
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
|
||||
pos_freqs_device, _ = self._get_device_freqs(device)
|
||||
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
@@ -311,8 +317,9 @@ class QwenEmbedRope(nn.Module):
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
pos_freqs, neg_freqs = (
|
||||
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
|
||||
)
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
@@ -367,6 +374,11 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
@lru_cache_unless_export(maxsize=None)
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Transfer both frequency tables to `device` and return them as `(pos_freqs, neg_freqs)`."""
    pos_on_device = self.pos_freqs.to(device)
    neg_on_device = self.neg_freqs.to(device)
    return pos_on_device, neg_on_device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
|
||||
@@ -421,8 +433,9 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
|
||||
max_vid_index = max(max_vid_index, layer_num)
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
# Create device-specific copy for text freqs without modifying self.pos_freqs
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
|
||||
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
|
||||
pos_freqs_device, _ = self._get_device_freqs(device)
|
||||
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
@@ -430,8 +443,9 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
pos_freqs, neg_freqs = (
|
||||
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
|
||||
)
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
@@ -452,8 +466,9 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
pos_freqs, neg_freqs = (
|
||||
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
|
||||
)
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
@@ -88,6 +88,10 @@ else:
|
||||
"QwenImageLayeredModularPipeline",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
]
|
||||
_import_structure["ltx"] = [
|
||||
"LTXAutoBlocks",
|
||||
"LTXModularPipeline",
|
||||
]
|
||||
_import_structure["z_image"] = [
|
||||
"ZImageAutoBlocks",
|
||||
"ZImageModularPipeline",
|
||||
@@ -119,6 +123,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HeliosPyramidDistilledModularPipeline,
|
||||
HeliosPyramidModularPipeline,
|
||||
)
|
||||
from .ltx import LTXAutoBlocks, LTXModularPipeline
|
||||
from .modular_pipeline import (
|
||||
AutoPipelineBlocks,
|
||||
BlockState,
|
||||
|
||||
47
src/diffusers/modular_pipelines/ltx/__init__.py
Normal file
47
src/diffusers/modular_pipelines/ltx/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
# Placeholder objects exposed when torch/transformers are missing, keyed by public name.
_dummy_objects = {}
# Maps each submodule name to the public names it provides; consumed by `_LazyModule` below.
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modular_blocks_ltx"] = ["LTXAutoBlocks", "LTXBlocks", "LTXImage2VideoBlocks"]
    _import_structure["modular_pipeline"] = ["LTXModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import the real symbols (or the dummy ones) immediately.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modular_blocks_ltx import LTXAutoBlocks, LTXBlocks, LTXImage2VideoBlocks
        from .modular_pipeline import LTXModularPipeline
else:
    import sys

    # Lazy path: replace this module in `sys.modules` with a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy placeholders (if any were collected) to the lazy module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
||||
392
src/diffusers/modular_pipelines/ltx/before_denoise.py
Normal file
392
src/diffusers/modular_pipelines/ltx/before_denoise.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import LTXModularPipeline, LTXVideoPachifier
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """
    Linearly interpolate (or extrapolate) a shift value from a sequence length.

    The line passes through `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`; the value of that line at
    `image_seq_len` is returned.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
    scheduler,
    num_inference_steps: int | None = None,
    device: str | torch.device | None = None,
    timesteps: list[int] | None = None,
    sigmas: list[float] | None = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are passed, or if the one that is passed is not a parameter of
        the scheduler's `set_timesteps` method.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        # Inspect the signature up front so an unsupported scheduler gets a clear error message.
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        # The step count is taken from the schedule the scheduler produced, not from the argument.
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: `num_inference_steps` is returned unchanged.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class LTXTextInputStep(ModularPipelineBlocks):
    """
    Derive `batch_size` / `dtype` from `prompt_embeds` and expand the text inputs for `num_videos_per_prompt`.

    After this step every text tensor has `batch_size * num_videos_per_prompt` rows, ordered so that all copies of
    one prompt are adjacent (p0, p0, ..., p1, p1, ...). The stored `batch_size` output is the prompt batch size
    *before* that expansion.
    """

    model_name = "ltx"

    @property
    def description(self) -> str:
        return (
            "Input processing step that:\n"
            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            " 2. Adjusts input tensor shapes based on `batch_size` and `num_videos_per_prompt`"
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
            InputParam.template("prompt_embeds", required=True),
            InputParam.template("prompt_embeds_mask", name="prompt_attention_mask"),
            InputParam.template("negative_prompt_embeds"),
            InputParam.template("negative_prompt_embeds_mask", name="negative_prompt_attention_mask"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("batch_size", type_hint=int),
            OutputParam("dtype", type_hint=torch.dtype),
        ]

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
        """Expand the prompt tensors in `state` and record `batch_size` / `dtype`; returns `(components, state)`."""
        block_state = self.get_block_state(state)

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype
        num_videos = block_state.num_videos_per_prompt

        # Repeat prompt_embeds for num_videos_per_prompt.
        # repeat along the sequence dim + view yields an interleaved batch: row b * num_videos + k is prompt b.
        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, num_videos, 1)
        block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * num_videos, seq_len, -1)

        if block_state.prompt_attention_mask is not None:
            # Must use the same interleaved ordering as the embeddings above. Tiling with `repeat(num_videos, 1)`
            # would pair prompt b's embedding with another prompt's mask when batch_size > 1 and num_videos > 1.
            block_state.prompt_attention_mask = block_state.prompt_attention_mask.repeat_interleave(
                num_videos, dim=0
            )

        if block_state.negative_prompt_embeds is not None:
            _, seq_len, _ = block_state.negative_prompt_embeds.shape
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, num_videos, 1)
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
                block_state.batch_size * num_videos, seq_len, -1
            )

        if block_state.negative_prompt_attention_mask is not None:
            # Same alignment requirement as for the positive mask.
            block_state.negative_prompt_attention_mask = block_state.negative_prompt_attention_mask.repeat_interleave(
                num_videos, dim=0
            )

        self.set_block_state(state, block_state)
        return components, state
|
||||
|
||||
|
||||
class LTXSetTimestepsStep(ModularPipelineBlocks):
    """
    Configure the scheduler's timestep schedule and compute `rope_interpolation_scale`.

    The shift parameter `mu` is derived from the latent video sequence length and forwarded to the scheduler for
    both custom `timesteps` and (custom or default) `sigmas`.
    """

    model_name = "ltx"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
        ]

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for inference"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("num_inference_steps"),
            InputParam.template("timesteps"),
            InputParam.template("sigmas"),
            InputParam.template("height", default=512),
            InputParam.template("width", default=704),
            InputParam("num_frames", type_hint=int, default=161),
            InputParam("frame_rate", type_hint=int, default=25),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor),
            OutputParam("num_inference_steps", type_hint=int),
            OutputParam("rope_interpolation_scale", type_hint=tuple),
        ]

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
        """Set `timesteps`, `num_inference_steps` and `rope_interpolation_scale`; returns `(components, state)`."""
        block_state = self.get_block_state(state)
        device = components._execution_device

        height = block_state.height
        width = block_state.width
        num_frames = block_state.num_frames
        frame_rate = block_state.frame_rate

        # Number of latent tokens the transformer will see for this video size.
        latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1
        latent_height = height // components.vae_spatial_compression_ratio
        latent_width = width // components.vae_spatial_compression_ratio
        video_sequence_length = latent_num_frames * latent_height * latent_width

        custom_timesteps = block_state.timesteps
        sigmas = block_state.sigmas

        # Computed once for both paths: a scheduler configured for dynamic shifting needs `mu` whether the
        # schedule comes from timesteps or from sigmas.
        mu = calculate_shift(
            video_sequence_length,
            components.scheduler.config.get("base_image_seq_len", 256),
            components.scheduler.config.get("max_image_seq_len", 4096),
            components.scheduler.config.get("base_shift", 0.5),
            components.scheduler.config.get("max_shift", 1.15),
        )

        if custom_timesteps is not None:
            # User provided custom timesteps, don't compute sigmas
            block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
                components.scheduler,
                block_state.num_inference_steps,
                device,
                custom_timesteps,
                mu=mu,
            )
        else:
            if sigmas is None:
                sigmas = np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)

            block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
                components.scheduler,
                block_state.num_inference_steps,
                device,
                sigmas=sigmas,
                mu=mu,
            )

        # (temporal, height, width) scale factors built from the VAE compression ratios and the frame rate.
        block_state.rope_interpolation_scale = (
            components.vae_temporal_compression_ratio / frame_rate,
            components.vae_spatial_compression_ratio,
            components.vae_spatial_compression_ratio,
        )

        self.set_block_state(state, block_state)
        return components, state
|
||||
|
||||
|
||||
class LTXPrepareLatentsStep(ModularPipelineBlocks):
    """
    Produce the initial `latents` for text-to-video generation.

    Caller-supplied latents are only moved to the execution device as float32; otherwise float32 noise of shape
    `(batch, channels, frames, height, width)` is sampled and packed with the pachifier.
    """

    model_name = "ltx"

    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the latents for the text-to-video generation process"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        pachifier_spec = ComponentSpec(
            "pachifier",
            LTXVideoPachifier,
            config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
            default_creation_method="from_config",
        )
        return [pachifier_spec]

    @property
    def inputs(self) -> list[InputParam]:
        params = [
            InputParam.template("height", default=512),
            InputParam.template("width", default=704),
            InputParam("num_frames", type_hint=int, default=161),
            InputParam.template("latents"),
            InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
            InputParam.template("generator"),
            InputParam.template("batch_size", required=True),
        ]
        return params

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor)]

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
        """Fill `latents` in `state` (sampling noise when none was given); returns `(components, state)`."""
        block_state = self.get_block_state(state)
        target_device = components._execution_device

        # One latent video per (prompt, video-per-prompt) pair.
        total_batch = block_state.batch_size * block_state.num_videos_per_prompt
        latent_channels = components.transformer.config.in_channels

        if block_state.latents is None:
            latent_height = block_state.height // components.vae_spatial_compression_ratio
            latent_width = block_state.width // components.vae_spatial_compression_ratio
            latent_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1

            noise_shape = (total_batch, latent_channels, latent_frames, latent_height, latent_width)
            block_state.latents = randn_tensor(
                noise_shape, generator=block_state.generator, device=target_device, dtype=torch.float32
            )
            block_state.latents = components.pachifier.pack_latents(block_state.latents)
        else:
            # Provided latents are used as given (no packing is applied here).
            block_state.latents = block_state.latents.to(device=target_device, dtype=torch.float32)

        self.set_block_state(state, block_state)
        return components, state
|
||||
|
||||
|
||||
class LTXImage2VideoPrepareLatentsStep(ModularPipelineBlocks):
    """
    Merge encoded image latents into packed noise latents for image-to-video generation.

    Latent frame 0 is taken from `image_latents`, every other frame from the unpacked incoming `latents`. The block
    also emits a `conditioning_mask` that is 1 on frame 0 and 0 elsewhere. Both outputs are packed by the pachifier.
    """

    model_name = "ltx"

    @property
    def description(self) -> str:
        return (
            "Prepare image-to-video latents: adds noise to pre-encoded image latents and creates a conditioning mask. "
            "Expects pure noise `latents` from LTXPrepareLatentsStep."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        pachifier_spec = ComponentSpec(
            "pachifier",
            LTXVideoPachifier,
            config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
            default_creation_method="from_config",
        )
        return [pachifier_spec]

    @property
    def inputs(self) -> list[InputParam]:
        params = [
            InputParam("image_latents", type_hint=torch.Tensor, required=True),
            InputParam.template("latents", required=True),
            InputParam.template("height", default=512),
            InputParam.template("width", default=704),
            InputParam("num_frames", type_hint=int, default=161),
            InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
            InputParam.template("batch_size", required=True),
        ]
        return params

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        outputs = [
            OutputParam("latents", type_hint=torch.Tensor),
            OutputParam("conditioning_mask", type_hint=torch.Tensor),
        ]
        return outputs

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
        """Write the blended `latents` and the `conditioning_mask` into `state`; returns `(components, state)`."""
        block_state = self.get_block_state(state)
        target_device = components._execution_device

        total_batch = block_state.batch_size * block_state.num_videos_per_prompt

        latent_height = block_state.height // components.vae_spatial_compression_ratio
        latent_width = block_state.width // components.vae_spatial_compression_ratio
        latent_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1

        image_cond = block_state.image_latents.to(device=target_device, dtype=torch.float32)
        if image_cond.shape[0] < total_batch:
            # Grow the image batch to the full batch, keeping copies of one image adjacent.
            copies = total_batch // image_cond.shape[0]
            image_cond = image_cond.repeat_interleave(copies, dim=0)
        # Tile along dim 2 (the frame dim of the 5D layout used below).
        image_cond = image_cond.repeat(1, 1, latent_frames, 1, 1)

        mask_shape = (image_cond.shape[0], 1, image_cond.shape[2], image_cond.shape[3], image_cond.shape[4])
        mask = torch.zeros(mask_shape, device=target_device, dtype=torch.float32)
        # Only the first latent frame is conditioned on the image.
        mask[:, :, 0] = 1.0

        noise = components.pachifier.unpack_latents(block_state.latents, latent_frames, latent_height, latent_width)
        blended = image_cond * mask + noise * (1 - mask)

        block_state.latents = components.pachifier.pack_latents(blended)
        block_state.conditioning_mask = components.pachifier.pack_latents(mask).squeeze(-1)

        self.set_block_state(state, block_state)
        return components, state
|
||||
132
src/diffusers/modular_pipelines/ltx/decoders.py
Normal file
132
src/diffusers/modular_pipelines/ltx/decoders.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKLLTXVideo
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import LTXVideoPachifier
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
|
||||
class LTXVaeDecoderStep(ModularPipelineBlocks):
    """
    Decode packed, denoised latents into videos.

    The latents are unpacked to `[B, C, F, H, W]`, denormalized with the VAE statistics, optionally blended with
    fresh noise (only when the VAE is configured with `timestep_conditioning`), decoded, and post-processed to
    `output_type`.
    """

    model_name = "ltx"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLLTXVideo),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 32}),
                default_creation_method="from_config",
            ),
            ComponentSpec(
                "pachifier",
                LTXVideoPachifier,
                config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into videos"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents", required=True),
            InputParam.template("output_type", default="np"),
            InputParam.template("height", default=512),
            InputParam.template("width", default=704),
            InputParam("num_frames", type_hint=int, default=161),
            InputParam("decode_timestep", default=0.0),
            InputParam("decode_noise_scale", default=None),
            InputParam.template("generator"),
            InputParam.template("batch_size"),
            InputParam.template("dtype", required=True),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [OutputParam.template("videos")]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Decode `latents` from `state` into `videos`; returns `(components, state)`."""
        block_state = self.get_block_state(state)
        vae = components.vae

        latents = block_state.latents

        height = block_state.height
        width = block_state.width
        num_frames = block_state.num_frames

        latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1
        latent_height = height // components.vae_spatial_compression_ratio
        latent_width = width // components.vae_spatial_compression_ratio

        latents = components.pachifier.unpack_latents(latents, latent_num_frames, latent_height, latent_width)
        latents = _denormalize_latents(latents, vae.latents_mean, vae.latents_std, vae.config.scaling_factor)
        latents = latents.to(block_state.dtype)

        if not vae.config.timestep_conditioning:
            timestep = None
        else:
            device = latents.device
            # Size the per-sample values from the latents themselves. The `batch_size` input is optional and holds
            # the prompt batch size, which can be smaller than the number of latent videos being decoded.
            batch_size = latents.shape[0]
            decode_timestep = block_state.decode_timestep
            decode_noise_scale = block_state.decode_noise_scale

            noise = randn_tensor(latents.shape, generator=block_state.generator, device=device, dtype=latents.dtype)
            if not isinstance(decode_timestep, list):
                decode_timestep = [decode_timestep] * batch_size
            if decode_noise_scale is None:
                # Default the noise scale to the decode timestep.
                decode_noise_scale = decode_timestep
            elif not isinstance(decode_noise_scale, list):
                decode_noise_scale = [decode_noise_scale] * batch_size

            timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
            # Reshape to (B, 1, 1, 1, 1) so each sample's scale broadcasts over its whole latent video.
            decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
                :, None, None, None, None
            ]
            latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

        latents = latents.to(vae.dtype)
        video = vae.decode(latents, timestep, return_dict=False)[0]
        block_state.videos = components.video_processor.postprocess_video(video, output_type=block_state.output_type)

        self.set_block_state(state, block_state)
        return components, state
|
||||
458
src/diffusers/modular_pipelines/ltx/denoise.py
Normal file
458
src/diffusers/modular_pipelines/ltx/denoise.py
Normal file
@@ -0,0 +1,458 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import LTXVideoTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam
|
||||
from .modular_pipeline import LTXModularPipeline, LTXVideoPachifier
|
||||
|
||||
|
||||
class LTXLoopBeforeDenoiser(ModularPipelineBlocks):
    """Loop sub-block that stages the current latents as the denoiser input.

    On each call, `block_state.latents` is cast to `block_state.dtype` and stored as
    `block_state.latent_model_input`. The loop index `i` and timestep `t` are accepted but not used here.
    """

    model_name = "ltx"

    @property
    def description(self) -> str:
        pieces = (
            "Step within the denoising loop that prepares the latent input for the denoiser. ",
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` ",
            "object (e.g. `LTXDenoiseLoopWrapper`)",
        )
        return "".join(pieces)

    @property
    def inputs(self) -> list[InputParam]:
        # Both inputs are mandatory template parameters.
        return [InputParam.template(name, required=True) for name in ("latents", "dtype")]

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        model_input = block_state.latents.to(block_state.dtype)
        block_state.latent_model_input = model_input
        return components, block_state
|
||||
|
||||
|
||||
class LTXLoopDenoiser(ModularPipelineBlocks):
    """Loop sub-block that runs the transformer once per guider batch and lets the guider combine the results.

    Args:
        guider_input_fields (`dict[str, Any]`, *optional*):
            Maps a transformer keyword argument name to the block-state field name(s) handed to the guider. A value
            is either a tuple of field names (the default pairs each prompt field with its negative counterpart)
            or a single field name. When `None`, the default prompt-embeds / attention-mask mapping is used.

    Raises:
        ValueError: If `guider_input_fields` is provided but is not a dictionary.
    """

    model_name = "ltx"

    def __init__(
        self,
        guider_input_fields: dict[str, Any] | None = None,
    ):
        if guider_input_fields is None:
            field_map = {
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
            }
        elif isinstance(guider_input_fields, dict):
            field_map = guider_input_fields
        else:
            raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")

        # The mapping is stored before the base initializer runs.
        self._guider_input_fields = field_map
        super().__init__()

    @property
    def expected_components(self) -> list[ComponentSpec]:
        guider_spec = ComponentSpec(
            "guider",
            ClassifierFreeGuidance,
            config=FrozenDict({"guidance_scale": 3.0}),
            default_creation_method="from_config",
        )
        transformer_spec = ComponentSpec("transformer", LTXVideoTransformer3DModel)
        return [guider_spec, transformer_spec]

    @property
    def description(self) -> str:
        pieces = (
            "Step within the denoising loop that denoises the latents with guidance. ",
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` ",
            "object (e.g. `LTXDenoiseLoopWrapper`)",
        )
        return "".join(pieces)

    @property
    def inputs(self) -> list[tuple[str, Any]]:
        params = [
            InputParam.template("attention_kwargs"),
            InputParam.template("num_inference_steps", required=True),
            InputParam("rope_interpolation_scale", type_hint=tuple),
            InputParam.template("height"),
            InputParam.template("width"),
            InputParam("num_frames", type_hint=int),
        ]

        # Every field named in the guider mapping becomes a required tensor input; tuples contribute each of
        # their names, any other value contributes itself.
        for field_names in self._guider_input_fields.values():
            names = field_names if isinstance(field_names, tuple) else (field_names,)
            params.extend(InputParam(name=name, required=True, type_hint=torch.Tensor) for name in names)
        return params

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        guider = components.guider
        guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # Latent-grid sizes derived from the pixel-space sizes and the VAE compression ratios.
        temporal_ratio = components.vae_temporal_compression_ratio
        spatial_ratio = components.vae_spatial_compression_ratio
        grid_kwargs = {
            "num_frames": (block_state.num_frames - 1) // temporal_ratio + 1,
            "height": block_state.height // spatial_ratio,
            "width": block_state.width // spatial_ratio,
        }

        guider_batches = guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)

        for batch in guider_batches:
            guider.prepare_models(components.transformer)

            # Keep only the entries named in the guider mapping; tensors are cast to the working dtype.
            cond_kwargs = {}
            for key, value in batch.as_dict().items():
                if key not in self._guider_input_fields:
                    continue
                cond_kwargs[key] = value.to(block_state.dtype) if isinstance(value, torch.Tensor) else value

            model_input = block_state.latent_model_input
            cache_name = getattr(batch, guider._identifier_key, None)
            with components.transformer.cache_context(cache_name):
                # The prediction is stored on the batch so the guider can combine all batches afterwards.
                batch.noise_pred = components.transformer(
                    hidden_states=model_input,
                    timestep=t.expand(model_input.shape[0]).to(block_state.dtype),
                    rope_interpolation_scale=block_state.rope_interpolation_scale,
                    attention_kwargs=block_state.attention_kwargs,
                    return_dict=False,
                    **grid_kwargs,
                    **cond_kwargs,
                )[0]
            guider.cleanup_models(components.transformer)

        block_state.noise_pred = guider(guider_batches)[0]

        return components, block_state
|
||||
|
||||
|
||||
class LTXLoopAfterDenoiser(ModularPipelineBlocks):
    """Loop sub-block that advances `block_state.latents` by one scheduler step.

    The dtype the latents had before the step is restored if the scheduler returned a different one.
    """

    model_name = "ltx"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @property
    def description(self) -> str:
        pieces = (
            "Step within the denoising loop that updates the latents. ",
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` ",
            "object (e.g. `LTXDenoiseLoopWrapper`)",
        )
        return "".join(pieces)

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        dtype_before_step = block_state.latents.dtype

        stepped = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            return_dict=False,
        )[0]

        # Only cast when the scheduler changed the dtype.
        if stepped.dtype != dtype_before_step:
            stepped = stepped.to(dtype_before_step)
        block_state.latents = stepped

        return components, block_state
|
||||
|
||||
|
||||
class LTXDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Loop driver that calls `loop_step` once for every entry of `block_state.timesteps`.

    The progress bar is advanced on the final iteration, and otherwise whenever the number of completed iterations
    exceeds `num_warmup_steps` and is a multiple of `components.scheduler.order`.
    """

    model_name = "ltx"

    @property
    def description(self) -> str:
        pieces = (
            "Pipeline block that iteratively denoises the latents over `timesteps`. ",
            "The specific steps within each iteration can be customized with `sub_blocks` attributes",
        )
        return "".join(pieces)

    @property
    def loop_expected_components(self) -> list[ComponentSpec]:
        component_types = (
            ("scheduler", FlowMatchEulerDiscreteScheduler),
            ("transformer", LTXVideoTransformer3DModel),
        )
        return [ComponentSpec(name, component_type) for name, component_type in component_types]

    @property
    def loop_inputs(self) -> list[InputParam]:
        return [InputParam.template(name, required=True) for name in ("timesteps", "num_inference_steps")]

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Timesteps beyond `num_inference_steps * order` count as warm-up; never negative.
        surplus_steps = len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order
        block_state.num_warmup_steps = max(surplus_steps, 0)

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)

                completed = i + 1
                if completed == len(block_state.timesteps):
                    progress_bar.update()
                elif completed > block_state.num_warmup_steps and completed % components.scheduler.order == 0:
                    progress_bar.update()

        self.set_block_state(state, block_state)
        return components, state
|
||||
|
||||
|
||||
class LTXDenoiseStep(LTXDenoiseLoopWrapper):
    """Denoising loop wired with the before-denoiser, denoiser and after-denoiser sub-blocks, in that order."""

    # Sub-blocks run sequentially on every loop iteration; `block_names` labels them one-to-one.
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
    block_classes = [
        LTXLoopBeforeDenoiser,
        LTXLoopDenoiser(
            guider_input_fields={
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
            }
        ),
        LTXLoopAfterDenoiser,
    ]

    @property
    def description(self) -> str:
        lines = (
            "Denoise step that iteratively denoises the latents.\n",
            "Its loop logic is defined in `LTXDenoiseLoopWrapper.__call__` method.\n",
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n",
            " - `LTXLoopBeforeDenoiser`\n",
            " - `LTXLoopDenoiser`\n",
            " - `LTXLoopAfterDenoiser`\n",
            "This block supports text-to-video tasks.",
        )
        return "".join(lines)
|
||||
|
||||
|
||||
class LTXImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
    """Image-to-video loop sub-block that stages the model input and a mask-modulated timestep.

    Writes `block_state.latent_model_input` (latents cast to `block_state.dtype`) and
    `block_state.timestep_adjusted`, which is the timestep multiplied by `1 - conditioning_mask`.
    """

    model_name = "ltx"

    @property
    def description(self) -> str:
        pieces = (
            "Step within the i2v denoising loop that prepares the latent input and modulates ",
            "the timestep with the conditioning mask.",
        )
        return "".join(pieces)

    @property
    def inputs(self) -> list[InputParam]:
        latents_param = InputParam.template("latents", required=True)
        mask_param = InputParam("conditioning_mask", required=True, type_hint=torch.Tensor)
        dtype_param = InputParam.template("dtype", required=True)
        return [latents_param, mask_param, dtype_param]

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        model_input = block_state.latents.to(block_state.dtype)
        block_state.latent_model_input = model_input

        # One timestep per batch row with a trailing axis, so it broadcasts against the mask; positions where
        # the mask equals 1 end up with timestep 0.
        per_sample_timestep = t.expand(model_input.shape[0]).unsqueeze(-1)
        block_state.timestep_adjusted = per_sample_timestep * (1 - block_state.conditioning_mask)
        return components, block_state
|
||||
|
||||
|
||||
class LTXImage2VideoLoopDenoiser(ModularPipelineBlocks):
    """Image-to-video loop sub-block that runs the transformer per guider batch with the adjusted timestep.

    The timestep passed to the transformer is `block_state.timestep_adjusted` rather than the raw loop timestep.

    Args:
        guider_input_fields (`dict[str, Any]`, *optional*):
            Maps a transformer keyword argument name to the block-state field name(s) handed to the guider. A value
            is either a tuple of field names or a single field name. When `None`, the default prompt-embeds /
            attention-mask mapping is used.

    Raises:
        ValueError: If `guider_input_fields` is provided but is not a dictionary.
    """

    model_name = "ltx"

    def __init__(
        self,
        guider_input_fields: dict[str, Any] | None = None,
    ):
        if guider_input_fields is None:
            field_map = {
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
            }
        elif isinstance(guider_input_fields, dict):
            field_map = guider_input_fields
        else:
            raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")

        # The mapping is stored before the base initializer runs.
        self._guider_input_fields = field_map
        super().__init__()

    @property
    def expected_components(self) -> list[ComponentSpec]:
        guider_spec = ComponentSpec(
            "guider",
            ClassifierFreeGuidance,
            config=FrozenDict({"guidance_scale": 3.0}),
            default_creation_method="from_config",
        )
        transformer_spec = ComponentSpec("transformer", LTXVideoTransformer3DModel)
        return [guider_spec, transformer_spec]

    @property
    def description(self) -> str:
        pieces = (
            "Step within the i2v denoising loop that denoises the latents with guidance ",
            "using timestep modulated by the conditioning mask.",
        )
        return "".join(pieces)

    @property
    def inputs(self) -> list[tuple[str, Any]]:
        params = [
            InputParam.template("attention_kwargs"),
            InputParam.template("num_inference_steps", required=True),
            InputParam("rope_interpolation_scale", type_hint=tuple),
            InputParam.template("height"),
            InputParam.template("width"),
            InputParam("num_frames", type_hint=int),
        ]

        # Flatten the mapping values into one ordered list of field names.
        field_names = []
        for entry in self._guider_input_fields.values():
            field_names += list(entry) if isinstance(entry, tuple) else [entry]

        params += [InputParam(name=name, required=True, type_hint=torch.Tensor) for name in field_names]
        return params

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        guider = components.guider
        guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # Latent-grid sizes derived from the pixel-space sizes and the VAE compression ratios.
        frames_in_latent = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
        rows_in_latent = block_state.height // components.vae_spatial_compression_ratio
        cols_in_latent = block_state.width // components.vae_spatial_compression_ratio

        guider_batches = guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)

        for batch in guider_batches:
            guider.prepare_models(components.transformer)

            # Keep only the entries named in the guider mapping; tensors are cast to the working dtype.
            wanted_keys = self._guider_input_fields.keys()
            cond_kwargs = {}
            for key, value in batch.as_dict().items():
                if key in wanted_keys:
                    cond_kwargs[key] = value.to(block_state.dtype) if isinstance(value, torch.Tensor) else value

            cache_name = getattr(batch, guider._identifier_key, None)
            with components.transformer.cache_context(cache_name):
                # The prediction is stored on the batch so the guider can combine all batches afterwards.
                prediction = components.transformer(
                    hidden_states=block_state.latent_model_input,
                    timestep=block_state.timestep_adjusted,
                    num_frames=frames_in_latent,
                    height=rows_in_latent,
                    width=cols_in_latent,
                    rope_interpolation_scale=block_state.rope_interpolation_scale,
                    attention_kwargs=block_state.attention_kwargs,
                    return_dict=False,
                    **cond_kwargs,
                )[0]
                batch.noise_pred = prediction
            guider.cleanup_models(components.transformer)

        block_state.noise_pred = guider(guider_batches)[0]

        return components, block_state
|
||||
|
||||
|
||||
class LTXImage2VideoLoopAfterDenoiser(ModularPipelineBlocks):
    """Image-to-video loop sub-block that applies the scheduler step to every latent frame except the first.

    Latents and noise prediction are unpacked with the pachifier, index 0 along dim 2 is carried over untouched,
    the remaining slices are stepped by the scheduler, and the result is packed back into `block_state.latents`.
    """

    model_name = "ltx"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        pachifier_spec = ComponentSpec(
            "pachifier",
            LTXVideoPachifier,
            config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
            default_creation_method="from_config",
        )
        return [scheduler_spec, pachifier_spec]

    @property
    def description(self) -> str:
        pieces = (
            "Step within the i2v denoising loop that updates the latents, ",
            "applying the scheduler step only to frames after the first (conditioned) frame.",
        )
        return "".join(pieces)

    @property
    def inputs(self) -> list[InputParam]:
        size_params = [InputParam.template(name) for name in ("height", "width")]
        return size_params + [InputParam("num_frames", type_hint=int)]

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        # Latent-grid sizes derived from the pixel-space sizes and the VAE compression ratios.
        latent_dims = (
            (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1,
            block_state.height // components.vae_spatial_compression_ratio,
            block_state.width // components.vae_spatial_compression_ratio,
        )

        unpacked_pred = components.pachifier.unpack_latents(block_state.noise_pred, *latent_dims)
        unpacked_latents = components.pachifier.unpack_latents(block_state.latents, *latent_dims)

        # Split along dim 2: the first slice is kept as is, the rest goes through the scheduler.
        first_frame = unpacked_latents[:, :, :1]
        stepped_rest = components.scheduler.step(
            unpacked_pred[:, :, 1:], t, unpacked_latents[:, :, 1:], return_dict=False
        )[0]

        merged = torch.cat([first_frame, stepped_rest], dim=2)
        block_state.latents = components.pachifier.pack_latents(merged)

        return components, block_state
|
||||
|
||||
|
||||
class LTXImage2VideoDenoiseStep(LTXDenoiseLoopWrapper):
    """Denoising loop wired with the image-to-video before-denoiser, denoiser and after-denoiser sub-blocks."""

    # Sub-blocks run sequentially on every loop iteration; `block_names` labels them one-to-one.
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
    block_classes = [
        LTXImage2VideoLoopBeforeDenoiser,
        LTXImage2VideoLoopDenoiser(
            guider_input_fields={
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
            }
        ),
        LTXImage2VideoLoopAfterDenoiser,
    ]

    @property
    def description(self) -> str:
        lines = (
            "Denoise step for image-to-video that iteratively denoises the latents.\n",
            "The first frame is kept fixed via a conditioning mask.\n",
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n",
            " - `LTXImage2VideoLoopBeforeDenoiser`\n",
            " - `LTXImage2VideoLoopDenoiser`\n",
            " - `LTXImage2VideoLoopAfterDenoiser`",
        )
        return "".join(lines)
|
||||
273
src/diffusers/modular_pipelines/ltx/encoders.py
Normal file
273
src/diffusers/modular_pipelines/ltx/encoders.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import AutoencoderKLLTXVideo
|
||||
from ...utils import logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import LTXModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
components,
|
||||
prompt: str | list[str],
|
||||
max_sequence_length: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
text_inputs = components.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
||||
|
||||
prompt_embeds = components.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
|
||||
class LTXTextEncoderStep(ModularPipelineBlocks):
    """Pipeline block that turns `prompt` / `negative_prompt` into text embeddings and attention masks.

    Negative embeddings are only computed when `components.requires_unconditional_embeds` is truthy; otherwise
    the two negative outputs are set to `None`.
    """

    model_name = "ltx"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings to guide the video generation"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        text_encoder_spec = ComponentSpec("text_encoder", T5EncoderModel)
        tokenizer_spec = ComponentSpec("tokenizer", T5TokenizerFast)
        guider_spec = ComponentSpec(
            "guider",
            ClassifierFreeGuidance,
            config=FrozenDict({"guidance_scale": 3.0}),
            default_creation_method="from_config",
        )
        return [text_encoder_spec, tokenizer_spec, guider_spec]

    @property
    def inputs(self) -> list[InputParam]:
        prompt_params = [InputParam.template(name) for name in ("prompt", "negative_prompt")]
        return prompt_params + [InputParam.template("max_sequence_length", default=128)]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        positive = [
            OutputParam.template("prompt_embeds"),
            OutputParam.template("prompt_embeds_mask", name="prompt_attention_mask"),
        ]
        negative = [
            OutputParam.template("negative_prompt_embeds"),
            OutputParam.template("negative_prompt_embeds_mask", name="negative_prompt_attention_mask"),
        ]
        return positive + negative

    @staticmethod
    def check_inputs(block_state):
        """Raise `ValueError` when `block_state.prompt` is set but is neither a `str` nor a `list`."""
        prompt = block_state.prompt
        if prompt is None or isinstance(prompt, (str, list)):
            return
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")

    @staticmethod
    def encode_prompt(
        components,
        prompt: str,
        device: torch.device | None = None,
        prepare_unconditional_embeds: bool = True,
        negative_prompt: str | None = None,
        max_sequence_length: int = 128,
    ):
        """Encode the prompt and, optionally, the negative prompt.

        Args:
            components: Object exposing `text_encoder`, `tokenizer` and `_execution_device`.
            prompt: Prompt to encode; anything that is not a list is wrapped into a one-element list.
            device: Target device; falls back to `components._execution_device` when falsy.
            prepare_unconditional_embeds: Whether to also encode the negative prompt.
            negative_prompt: Negative prompt; a falsy value is replaced by `""`, and a string is repeated once per
                prompt.
            max_sequence_length: Token length forwarded to the tokenizer.

        Returns:
            `tuple`: `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
            negative_prompt_attention_mask)`; the last two are `None` when `prepare_unconditional_embeds` is
            falsy.

        Raises:
            ValueError: If the negative prompt list and the prompt list differ in length.
        """
        device = device or components._execution_device
        dtype = components.text_encoder.dtype

        prompt = prompt if isinstance(prompt, list) else [prompt]
        batch_size = len(prompt)

        encode_kwargs = {
            "components": components,
            "max_sequence_length": max_sequence_length,
            "device": device,
            "dtype": dtype,
        }
        prompt_embeds, prompt_attention_mask = _get_t5_prompt_embeds(prompt=prompt, **encode_kwargs)

        if not prepare_unconditional_embeds:
            return prompt_embeds, prompt_attention_mask, None, None

        negative_prompt = negative_prompt or ""
        if isinstance(negative_prompt, str):
            negative_prompt = batch_size * [negative_prompt]

        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        negative_prompt_embeds, negative_prompt_attention_mask = _get_t5_prompt_embeds(
            prompt=negative_prompt, **encode_kwargs
        )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        encoded = self.encode_prompt(
            components=components,
            prompt=block_state.prompt,
            device=block_state.device,
            prepare_unconditional_embeds=components.requires_unconditional_embeds,
            negative_prompt=block_state.negative_prompt,
            max_sequence_length=block_state.max_sequence_length,
        )
        positive_embeds, positive_mask, negative_embeds, negative_mask = encoded

        block_state.prompt_embeds = positive_embeds
        block_state.prompt_attention_mask = positive_mask
        block_state.negative_prompt_embeds = negative_embeds
        block_state.negative_prompt_attention_mask = negative_mask

        self.set_block_state(state, block_state)
        return components, state
|
||||
|
||||
|
||||
# Extracts latents from an encoder output object:
#   - output has `latent_dist` and `sample_mode == "sample"`: returns `latent_dist.sample(generator)`
#   - output has `latent_dist` and `sample_mode == "argmax"`: returns `latent_dist.mode()`
#   - otherwise, if the output has a `latents` attribute: returns that attribute
#   - anything else raises `AttributeError`
# NOTE(review): the `Copied from` marker below suggests the function body has to stay in sync with the referenced
# helper, so this description is kept above the marker rather than inside the function - confirm against the
# repository's copy-consistency tooling.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def _normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    """Shift and scale `latents` per channel: `(latents - mean) * scaling_factor / std`.

    `latents_mean` and `latents_std` are reshaped to `(1, C, 1, 1, 1)` and moved to the device and dtype of
    `latents`, so they broadcast over a `[B, C, F, H, W]` tensor.
    """

    def _per_channel(stat: torch.Tensor) -> torch.Tensor:
        # Broadcastable view of a per-channel statistic, matched to the latents' device and dtype.
        return stat.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)

    centered = latents - _per_channel(latents_mean)
    return centered * scaling_factor / _per_channel(latents_std)
|
||||
|
||||
|
||||
class LTXVaeEncoderStep(ModularPipelineBlocks):
    """Pipeline block that encodes the input image(s) with the VAE and stores normalized `image_latents`.

    Each image is encoded on its own, the latents are concatenated along dim 0 in float32, and the result is
    normalized with the VAE's `latents_mean` / `latents_std`.
    """

    model_name = "ltx"

    @property
    def description(self) -> str:
        return "VAE Encoder step that encodes an input image into latent space for image-to-video generation"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        vae_spec = ComponentSpec("vae", AutoencoderKLLTXVideo)
        processor_spec = ComponentSpec(
            "video_processor",
            VideoProcessor,
            config=FrozenDict({"vae_scale_factor": 32}),
            default_creation_method="from_config",
        )
        return [vae_spec, processor_spec]

    @property
    def inputs(self) -> list[InputParam]:
        image_param = InputParam.template("image", required=True)
        height_param = InputParam.template("height", default=512)
        width_param = InputParam.template("width", default=704)
        generator_param = InputParam.template("generator")
        return [image_param, height_param, width_param, generator_param]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        latents_output = OutputParam(
            "image_latents",
            type_hint=torch.Tensor,
            description="Encoded image latents from the VAE encoder",
        )
        return [latents_output]

    @torch.no_grad()
    def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        # Anything that is not already a tensor goes through the video processor first.
        image = block_state.image
        if not isinstance(image, torch.Tensor):
            image = components.video_processor.preprocess(image, height=block_state.height, width=block_state.width)
        # NOTE(review): the device / float32 move is assumed to apply to tensor inputs as well - confirm.
        image = image.to(device=device, dtype=torch.float32)

        vae_dtype = components.vae.dtype
        generator = block_state.generator
        # A list of generators is indexed per image; any other value is shared by all images.
        generator_per_image = isinstance(generator, list)

        encoded = []
        for index in range(image.shape[0]):
            # Add a leading dim and a singleton dim at index 2 (presumably batch and frame axes) before encoding.
            vae_input = image[index].unsqueeze(0).unsqueeze(2).to(vae_dtype)
            image_generator = generator[index] if generator_per_image else generator
            encoded.append(retrieve_latents(components.vae.encode(vae_input), image_generator))

        stacked = torch.cat(encoded, dim=0).to(torch.float32)
        block_state.image_latents = _normalize_latents(
            stacked, components.vae.latents_mean, components.vae.latents_std
        )

        self.set_block_state(state, block_state)
        return components, state
|
||||
487
src/diffusers/modular_pipelines/ltx/modular_blocks_ltx.py
Normal file
487
src/diffusers/modular_pipelines/ltx/modular_blocks_ltx.py
Normal file
@@ -0,0 +1,487 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import OutputParam
|
||||
from .before_denoise import (
|
||||
LTXImage2VideoPrepareLatentsStep,
|
||||
LTXPrepareLatentsStep,
|
||||
LTXSetTimestepsStep,
|
||||
LTXTextInputStep,
|
||||
)
|
||||
from .decoders import LTXVaeDecoderStep
|
||||
from .denoise import LTXDenoiseStep, LTXImage2VideoDenoiseStep
|
||||
from .encoders import LTXTextEncoderStep, LTXVaeEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class LTXCoreDenoiseStep(SequentialPipelineBlocks):
    """
    Denoise block that takes encoded conditions and runs the denoising process.

      Components:
          scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`LTXVideoPachifier`) guider
          (`ClassifierFreeGuidance`) transformer (`LTXVideoTransformer3DModel`)

      Inputs:
          num_videos_per_prompt (`int`, *optional*, defaults to 1):
              The number of images to generate per prompt.
          prompt_embeds (`Tensor`):
              text embeddings used to guide the image generation. Can be generated from text_encoder step.
          prompt_attention_mask (`Tensor`):
              mask for the text embeddings. Can be generated from text_encoder step.
          negative_prompt_embeds (`Tensor`, *optional*):
              negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
          negative_prompt_attention_mask (`Tensor`, *optional*):
              mask for the negative text embeddings. Can be generated from text_encoder step.
          num_inference_steps (`int`, *optional*, defaults to 50):
              The number of denoising steps.
          timesteps (`Tensor`, *optional*):
              Timesteps for the denoising process.
          sigmas (`list`, *optional*):
              Custom sigmas for the denoising process.
          height (`int`, *optional*, defaults to 512):
              The height in pixels of the generated image.
          width (`int`, *optional*, defaults to 704):
              The width in pixels of the generated image.
          num_frames (`int`, *optional*, defaults to 161):
              TODO: Add description.
          frame_rate (`int`, *optional*, defaults to 25):
              TODO: Add description.
          latents (`Tensor`, *optional*):
              Pre-generated noisy latents for image generation.
          generator (`Generator`, *optional*):
              Torch generator for deterministic generation.
          attention_kwargs (`dict`, *optional*):
              Additional kwargs for attention processors.

      Outputs:
          latents (`Tensor`):
              Denoised latents.
    """

    # NOTE(review): the `# auto_docstring` marker above the class suggests the docstring is tool-generated, so the
    # `TODO: Add description.` entries are presumably best fixed on the input definitions rather than here - confirm.

    model_name = "ltx"
    # Sub-blocks of this sequential block; `block_names` below labels them one-to-one, in the same order.
    block_classes = [
        LTXTextInputStep,
        LTXSetTimestepsStep,
        LTXPrepareLatentsStep,
        LTXDenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]

    @property
    def description(self) -> str:
        # One-line summary; same sentence as the first line of the class docstring.
        return "Denoise block that takes encoded conditions and runs the denoising process."

    @property
    def outputs(self) -> list[OutputParam]:
        # The only declared output is the `latents` template parameter.
        return [OutputParam.template("latents")]
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class LTXImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    """
    Denoise block for image-to-video that takes encoded conditions and image latents, and runs the denoising process.

      Components:
          scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`LTXVideoPachifier`) guider
          (`ClassifierFreeGuidance`) transformer (`LTXVideoTransformer3DModel`)

      Inputs:
          num_videos_per_prompt (`int`, *optional*, defaults to 1):
              The number of images to generate per prompt.
          prompt_embeds (`Tensor`):
              text embeddings used to guide the image generation. Can be generated from text_encoder step.
          prompt_attention_mask (`Tensor`):
              mask for the text embeddings. Can be generated from text_encoder step.
          negative_prompt_embeds (`Tensor`, *optional*):
              negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
          negative_prompt_attention_mask (`Tensor`, *optional*):
              mask for the negative text embeddings. Can be generated from text_encoder step.
          num_inference_steps (`int`, *optional*, defaults to 50):
              The number of denoising steps.
          timesteps (`Tensor`, *optional*):
              Timesteps for the denoising process.
          sigmas (`list`, *optional*):
              Custom sigmas for the denoising process.
          height (`int`, *optional*, defaults to 512):
              The height in pixels of the generated image.
          width (`int`, *optional*, defaults to 704):
              The width in pixels of the generated image.
          num_frames (`int`, *optional*, defaults to 161):
              TODO: Add description.
          frame_rate (`int`, *optional*, defaults to 25):
              TODO: Add description.
          latents (`Tensor`, *optional*):
              Pre-generated noisy latents for image generation.
          generator (`Generator`, *optional*):
              Torch generator for deterministic generation.
          image_latents (`Tensor`):
              TODO: Add description.
          attention_kwargs (`dict`, *optional*):
              Additional kwargs for attention processors.

      Outputs:
          latents (`Tensor`):
              Denoised latents.
    """

    model_name = "ltx"
    # Sub-blocks of this sequential block. `block_names` below has one entry per class, in the same order.
    # Compared with the text-to-video variant this list adds an image-to-video latent preparation step and
    # uses the image-to-video denoise step.
    block_classes = [
        LTXTextInputStep,
        LTXSetTimestepsStep,
        LTXPrepareLatentsStep,
        LTXImage2VideoPrepareLatentsStep,
        LTXImage2VideoDenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_i2v_latents", "denoise"]

    @property
    def description(self):
        # Same sentence as the first line of the class docstring.
        return "Denoise block for image-to-video that takes encoded conditions and image latents, and runs the denoising process."

    @property
    def outputs(self):
        # Only the `latents` output is declared for this block.
        return [OutputParam.template("latents")]
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class LTXBlocks(SequentialPipelineBlocks):
    """
    Modular pipeline blocks for LTX Video text-to-video.

      Components:
          text_encoder (`T5EncoderModel`) tokenizer (`T5TokenizerFast`) guider (`ClassifierFreeGuidance`) scheduler
          (`FlowMatchEulerDiscreteScheduler`) pachifier (`LTXVideoPachifier`) transformer
          (`LTXVideoTransformer3DModel`) vae (`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`)

      Inputs:
          prompt (`str`):
              The prompt or prompts to guide image generation.
          negative_prompt (`str`, *optional*):
              The prompt or prompts not to guide the image generation.
          max_sequence_length (`int`, *optional*, defaults to 128):
              Maximum sequence length for prompt encoding.
          num_videos_per_prompt (`int`, *optional*, defaults to 1):
              The number of images to generate per prompt.
          num_inference_steps (`int`, *optional*, defaults to 50):
              The number of denoising steps.
          timesteps (`Tensor`, *optional*):
              Timesteps for the denoising process.
          sigmas (`list`, *optional*):
              Custom sigmas for the denoising process.
          height (`int`, *optional*, defaults to 512):
              The height in pixels of the generated image.
          width (`int`, *optional*, defaults to 704):
              The width in pixels of the generated image.
          num_frames (`int`, *optional*, defaults to 161):
              TODO: Add description.
          frame_rate (`int`, *optional*, defaults to 25):
              TODO: Add description.
          latents (`Tensor`, *optional*):
              Pre-generated noisy latents for image generation.
          generator (`Generator`, *optional*):
              Torch generator for deterministic generation.
          attention_kwargs (`dict`, *optional*):
              Additional kwargs for attention processors.
          output_type (`str`, *optional*, defaults to np):
              Output format: 'pil', 'np', 'pt'.
          decode_timestep (`None`, *optional*, defaults to 0.0):
              TODO: Add description.
          decode_noise_scale (`None`, *optional*):
              TODO: Add description.

      Outputs:
          videos (`list`):
              The generated videos.
    """

    model_name = "ltx"
    # Text encoding, core denoising, then VAE decoding. `block_names` has one entry per class, in the same order.
    block_classes = [
        LTXTextEncoderStep,
        LTXCoreDenoiseStep,
        LTXVaeDecoderStep,
    ]
    block_names = ["text_encoder", "denoise", "decode"]

    @property
    def description(self):
        # Same sentence as the first line of the class docstring.
        return "Modular pipeline blocks for LTX Video text-to-video."

    @property
    def outputs(self):
        # Only the `videos` output is declared for this block.
        return [OutputParam.template("videos")]
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class LTXAutoVaeEncoderStep(AutoPipelineBlocks):
    """
    VAE encoder step that encodes the image input into its latent representation.
      This is an auto pipeline block that works for image-to-video tasks.
       - `LTXVaeEncoderStep` is used when `image` is provided.
       - If `image` is not provided, step will be skipped.

      Components:
          vae (`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`)

      Inputs:
          image (`Image | list`, *optional*):
              Reference image(s) for denoising. Can be a single image or list of images.
          height (`int`, *optional*, defaults to 512):
              The height in pixels of the generated image.
          width (`int`, *optional*, defaults to 704):
              The width in pixels of the generated image.
          generator (`Generator`, *optional*):
              Torch generator for deterministic generation.

      Outputs:
          image_latents (`Tensor`):
              Encoded image latents from the VAE encoder
    """

    model_name = "ltx"
    # Single candidate block; per `description` it is used when `image` is provided and skipped otherwise.
    block_classes = [LTXVaeEncoderStep]
    block_names = ["vae_encoder"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        # Multi-line text; the bullet lines spell out the selection rule encoded by `block_trigger_inputs`.
        return (
            "VAE encoder step that encodes the image input into its latent representation.\n"
            "This is an auto pipeline block that works for image-to-video tasks.\n"
            " - `LTXVaeEncoderStep` is used when `image` is provided.\n"
            " - If `image` is not provided, step will be skipped."
        )
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class LTXAutoCoreDenoiseStep(AutoPipelineBlocks):
    """
    Auto denoise block that selects the appropriate denoise pipeline based on inputs.
       - `LTXImage2VideoCoreDenoiseStep` is used when `image_latents` is provided.
       - `LTXCoreDenoiseStep` is used otherwise (text-to-video).

      Components:
          scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`LTXVideoPachifier`) guider
          (`ClassifierFreeGuidance`) transformer (`LTXVideoTransformer3DModel`)

      Inputs:
          num_videos_per_prompt (`int`, *optional*, defaults to 1):
              The number of images to generate per prompt.
          prompt_embeds (`Tensor`):
              text embeddings used to guide the image generation. Can be generated from text_encoder step.
          prompt_attention_mask (`Tensor`):
              mask for the text embeddings. Can be generated from text_encoder step.
          negative_prompt_embeds (`Tensor`):
              negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
          negative_prompt_attention_mask (`Tensor`):
              mask for the negative text embeddings. Can be generated from text_encoder step.
          num_inference_steps (`int`):
              The number of denoising steps.
          timesteps (`Tensor`):
              Timesteps for the denoising process.
          sigmas (`list`, *optional*):
              Custom sigmas for the denoising process.
          height (`int`, *optional*, defaults to 512):
              The height in pixels of the generated image.
          width (`int`, *optional*, defaults to 704):
              The width in pixels of the generated image.
          num_frames (`int`, *optional*, defaults to 161):
              TODO: Add description.
          frame_rate (`int`, *optional*, defaults to 25):
              TODO: Add description.
          latents (`Tensor`):
              Pre-generated noisy latents for image generation.
          generator (`Generator`, *optional*):
              Torch generator for deterministic generation.
          image_latents (`Tensor`, *optional*):
              TODO: Add description.
          attention_kwargs (`dict`, *optional*):
              Additional kwargs for attention processors.

      Outputs:
          latents (`Tensor`):
              Denoised latents.
    """

    model_name = "ltx"
    # The three lists below line up by position: candidate class, its name, and its trigger input.
    block_classes = [LTXImage2VideoCoreDenoiseStep, LTXCoreDenoiseStep]
    block_names = ["image2video", "text2video"]
    # Per `description`: the first block is used when `image_latents` is provided; the `None` entry marks the
    # block that is used otherwise.
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self):
        # Multi-line text; the bullet lines spell out the selection rule encoded by `block_trigger_inputs`.
        return (
            "Auto denoise block that selects the appropriate denoise pipeline based on inputs.\n"
            " - `LTXImage2VideoCoreDenoiseStep` is used when `image_latents` is provided.\n"
            " - `LTXCoreDenoiseStep` is used otherwise (text-to-video)."
        )
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class LTXAutoBlocks(SequentialPipelineBlocks):
    """
    Auto blocks for LTX Video that support both text-to-video and image-to-video workflows.

      Supported workflows:
        - `text2video`: requires `prompt`
        - `image2video`: requires `image`, `prompt`

      Components:
          text_encoder (`T5EncoderModel`) tokenizer (`T5TokenizerFast`) guider (`ClassifierFreeGuidance`) vae
          (`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`) scheduler (`FlowMatchEulerDiscreteScheduler`)
          pachifier (`LTXVideoPachifier`) transformer (`LTXVideoTransformer3DModel`)

      Inputs:
          prompt (`str`):
              The prompt or prompts to guide image generation.
          negative_prompt (`str`, *optional*):
              The prompt or prompts not to guide the image generation.
          max_sequence_length (`int`, *optional*, defaults to 128):
              Maximum sequence length for prompt encoding.
          image (`Image | list`, *optional*):
              Reference image(s) for denoising. Can be a single image or list of images.
          height (`int`, *optional*, defaults to 512):
              The height in pixels of the generated image.
          width (`int`, *optional*, defaults to 704):
              The width in pixels of the generated image.
          generator (`Generator`, *optional*):
              Torch generator for deterministic generation.
          num_videos_per_prompt (`int`, *optional*, defaults to 1):
              The number of images to generate per prompt.
          num_inference_steps (`int`):
              The number of denoising steps.
          timesteps (`Tensor`):
              Timesteps for the denoising process.
          sigmas (`list`, *optional*):
              Custom sigmas for the denoising process.
          num_frames (`int`, *optional*, defaults to 161):
              TODO: Add description.
          frame_rate (`int`, *optional*, defaults to 25):
              TODO: Add description.
          latents (`Tensor`):
              Pre-generated noisy latents for image generation.
          image_latents (`Tensor`, *optional*):
              TODO: Add description.
          attention_kwargs (`dict`, *optional*):
              Additional kwargs for attention processors.
          output_type (`str`, *optional*, defaults to np):
              Output format: 'pil', 'np', 'pt'.
          decode_timestep (`None`, *optional*, defaults to 0.0):
              TODO: Add description.
          decode_noise_scale (`None`, *optional*):
              TODO: Add description.

      Outputs:
          videos (`list`):
              The generated videos.
    """

    model_name = "ltx"
    # Text encoding, optional image encoding, auto-selected denoising, then decoding.
    # `block_names` has one entry per class, in the same order.
    block_classes = [
        LTXTextEncoderStep,
        LTXAutoVaeEncoderStep,
        LTXAutoCoreDenoiseStep,
        LTXVaeDecoderStep,
    ]
    block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
    # Workflow name -> inputs it requires; mirrors the "Supported workflows" section of the docstring.
    _workflow_map = {
        "text2video": {"prompt": True},
        "image2video": {"image": True, "prompt": True},
    }

    @property
    def description(self):
        # Same sentence as the first line of the class docstring.
        return "Auto blocks for LTX Video that support both text-to-video and image-to-video workflows."

    @property
    def outputs(self):
        # Only the `videos` output is declared for this block.
        return [OutputParam.template("videos")]
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class LTXImage2VideoBlocks(SequentialPipelineBlocks):
    """
    Modular pipeline blocks for LTX Video image-to-video.

      Components:
          text_encoder (`T5EncoderModel`) tokenizer (`T5TokenizerFast`) guider (`ClassifierFreeGuidance`) vae
          (`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`) scheduler (`FlowMatchEulerDiscreteScheduler`)
          pachifier (`LTXVideoPachifier`) transformer (`LTXVideoTransformer3DModel`)

      Inputs:
          prompt (`str`):
              The prompt or prompts to guide image generation.
          negative_prompt (`str`, *optional*):
              The prompt or prompts not to guide the image generation.
          max_sequence_length (`int`, *optional*, defaults to 128):
              Maximum sequence length for prompt encoding.
          image (`Image | list`, *optional*):
              Reference image(s) for denoising. Can be a single image or list of images.
          height (`int`, *optional*, defaults to 512):
              The height in pixels of the generated image.
          width (`int`, *optional*, defaults to 704):
              The width in pixels of the generated image.
          generator (`Generator`, *optional*):
              Torch generator for deterministic generation.
          num_videos_per_prompt (`int`, *optional*, defaults to 1):
              The number of images to generate per prompt.
          num_inference_steps (`int`, *optional*, defaults to 50):
              The number of denoising steps.
          timesteps (`Tensor`, *optional*):
              Timesteps for the denoising process.
          sigmas (`list`, *optional*):
              Custom sigmas for the denoising process.
          num_frames (`int`, *optional*, defaults to 161):
              TODO: Add description.
          frame_rate (`int`, *optional*, defaults to 25):
              TODO: Add description.
          latents (`Tensor`, *optional*):
              Pre-generated noisy latents for image generation.
          image_latents (`Tensor`):
              TODO: Add description.
          attention_kwargs (`dict`, *optional*):
              Additional kwargs for attention processors.
          output_type (`str`, *optional*, defaults to np):
              Output format: 'pil', 'np', 'pt'.
          decode_timestep (`None`, *optional*, defaults to 0.0):
              TODO: Add description.
          decode_noise_scale (`None`, *optional*):
              TODO: Add description.

      Outputs:
          videos (`list`):
              The generated videos.
    """

    model_name = "ltx"
    # `block_names` has one entry per class, in the same order.
    # NOTE(review): the encoder slot uses the auto (skippable) VAE encoder while the denoise slot is the
    # image-to-video-only step, which is why the docstring lists `image` as optional but `image_latents`
    # as required - confirm this is intended.
    block_classes = [
        LTXTextEncoderStep,
        LTXAutoVaeEncoderStep,
        LTXImage2VideoCoreDenoiseStep,
        LTXVaeDecoderStep,
    ]
    block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]

    @property
    def description(self):
        # Same sentence as the first line of the class docstring.
        return "Modular pipeline blocks for LTX Video image-to-video."

    @property
    def outputs(self):
        # Only the `videos` output is declared for this block.
        return [OutputParam.template("videos")]
|
||||
95
src/diffusers/modular_pipelines/ltx/modular_pipeline.py
Normal file
95
src/diffusers/modular_pipelines/ltx/modular_pipeline.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import LTXVideoLoraLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LTXVideoPachifier(ConfigMixin):
    """
    Converts LTX Video latents between a 5D grid layout and a packed 3D token layout.

    The spatial patch size (`patch_size`) and temporal patch size (`patch_size_t`) are read from the registered
    config.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, patch_size: int = 1, patch_size_t: int = 1):
        super().__init__()

    def pack_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """
        Pack a 5D latent tensor `[B, C, F, H, W]` into tokens `[B, F' * H' * W', C * p_t * p * p]`, where
        `F' = F // p_t`, `H' = H // p` and `W' = W // p`.
        """
        p = self.config.patch_size
        p_t = self.config.patch_size_t
        batch_size, _, num_frames, height, width = latents.shape

        # Split every axis into (patch count, patch size): [B, C, F', p_t, H', p, W', p].
        patched_shape = (batch_size, -1, num_frames // p_t, p_t, height // p, p, width // p, p)
        # Bring the patch counts forward: [B, F', H', W', C, p_t, p, p].
        patches = latents.reshape(*patched_shape).permute(0, 2, 4, 6, 1, 3, 5, 7)
        # Merge (C, p_t, p, p) into the feature axis, then (F', H', W') into the token axis.
        return patches.flatten(4, 7).flatten(1, 3)

    def unpack_latents(self, latents: torch.Tensor, num_frames: int, height: int, width: int) -> torch.Tensor:
        """
        Reverse `pack_latents`: turn packed tokens back into `[B, C, F * p_t, H * p, W * p]`.

        `num_frames`, `height` and `width` are the patch counts along each axis (i.e. the post-patch sizes).
        """
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        # [B, F, H, W, C, p_t, p, p]
        grid = latents.reshape(latents.size(0), num_frames, height, width, -1, p_t, p, p)
        # Pair every patch count with its patch size: [B, C, F, p_t, H, p, W, p].
        grid = grid.permute(0, 4, 1, 5, 2, 6, 3, 7)
        # Collapse the pairs from the innermost axis outwards so earlier axis indices stay valid.
        return grid.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
|
||||
class LTXModularPipeline(ModularPipeline, LTXVideoLoraLoaderMixin):
    """
    Modular pipeline for LTX Video with LoRA loading support.

    > [!WARNING] > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "LTXAutoBlocks"

    @property
    def vae_spatial_compression_ratio(self):
        """Spatial compression ratio of the loaded VAE, or 32 when no VAE is attached."""
        vae = getattr(self, "vae", None)
        if vae is None:
            return 32
        return vae.spatial_compression_ratio

    @property
    def vae_temporal_compression_ratio(self):
        """Temporal compression ratio of the loaded VAE, or 8 when no VAE is attached."""
        vae = getattr(self, "vae", None)
        if vae is None:
            return 8
        return vae.temporal_compression_ratio

    @property
    def requires_unconditional_embeds(self):
        """Whether an attached guider is enabled and uses more than one condition; False without a guider."""
        guider = getattr(self, "guider", None)
        if guider is None:
            return False
        return guider._enabled and guider.num_conditions > 1
|
||||
@@ -132,6 +132,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
("z-image", _create_default_map_fn("ZImageModularPipeline")),
|
||||
("helios", _create_default_map_fn("HeliosModularPipeline")),
|
||||
("helios-pyramid", _helios_pyramid_map_fn),
|
||||
("ltx", _create_default_map_fn("LTXModularPipeline")),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -335,6 +335,7 @@ else:
|
||||
)
|
||||
_import_structure["mochi"] = ["MochiPipeline"]
|
||||
_import_structure["omnigen"] = ["OmniGenPipeline"]
|
||||
_import_structure["ernie_image"] = ["ErnieImagePipeline"]
|
||||
_import_structure["ovis_image"] = ["OvisImagePipeline"]
|
||||
_import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"]
|
||||
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
|
||||
@@ -678,6 +679,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
)
|
||||
from .ernie_image import ErnieImagePipeline
|
||||
from .flux import (
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
|
||||
@@ -5,10 +5,13 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms.functional import normalize, resize
|
||||
|
||||
from ...utils import get_logger, load_image
|
||||
from ...utils import get_logger, is_torchvision_available, load_image
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms.functional import normalize, resize
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
47
src/diffusers/pipelines/ernie_image/__init__.py
Normal file
47
src/diffusers/pipelines/ernie_image/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
# Placeholder objects exported under the real names when torch/transformers are missing.
_dummy_objects = {}
# Submodule name -> public names; handed to `_LazyModule` at the bottom of this file.
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa: F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_ernie_image"] = ["ErnieImagePipeline"]

# Eager path: taken for static type checkers or when DIFFUSERS_SLOW_IMPORT is set.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_ernie_image import ErnieImagePipeline
else:
    import sys

    # Lazy path: replace this module in `sys.modules` with a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy placeholders (if any) to the replacement module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
||||
389
src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py
Normal file
389
src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Ernie-Image Pipeline for HuggingFace Diffusers.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from ...models import AutoencoderKLFlux2
|
||||
from ...models.transformers import ErnieImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import ErnieImagePipelineOutput
|
||||
|
||||
|
||||
class ErnieImagePipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using ErnieImageTransformer2DModel.
|
||||
|
||||
This pipeline uses:
|
||||
- A custom DiT transformer model
|
||||
- A Flux2-style VAE for encoding/decoding latents
|
||||
- A text encoder (e.g., Qwen) for text conditioning
|
||||
- Flow Matching Euler Discrete Scheduler
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "pe->text_encoder->transformer->vae"
|
||||
# For SGLang fallback ...
|
||||
_optional_components = ["pe", "pe_tokenizer"]
|
||||
_callback_tensor_inputs = ["latents"]
|
||||
|
||||
def __init__(
    self,
    transformer: ErnieImageTransformer2DModel,
    vae: AutoencoderKLFlux2,
    text_encoder: AutoModel,
    tokenizer: AutoTokenizer,
    scheduler: FlowMatchEulerDiscreteScheduler,
    pe: Optional[AutoModelForCausalLM] = None,
    pe_tokenizer: Optional[AutoTokenizer] = None,
):
    """
    Register the pipeline components and derive the VAE scale factor.

    `pe` and `pe_tokenizer` are optional (they are also listed in `_optional_components`).
    """
    super().__init__()

    components = {
        "transformer": transformer,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "pe": pe,
        "pe_tokenizer": pe_tokenizer,
    }
    self.register_modules(**components)

    # One factor of two per entry in the VAE's `block_out_channels`; falls back to 16 without a VAE.
    if getattr(self, "vae", None):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels))
    else:
        self.vae_scale_factor = 16
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@torch.no_grad()
def _enhance_prompt_with_pe(
    self,
    prompt: str,
    device: torch.device,
    width: int = 1024,
    height: int = 1024,
    system_prompt: Optional[str] = None,
    temperature: float = 0.6,
    top_p: float = 0.95,
) -> str:
    """
    Rewrite a prompt with the PE model and return the generated text.

    The user turn is a JSON object holding the prompt and the target `width`/`height`; an optional system
    turn is placed before it. The chat is rendered with `pe_tokenizer.apply_chat_template`, passed to
    `pe.generate`, and only the tokens generated after the input are decoded and stripped.
    """
    request = {"prompt": prompt, "width": width, "height": height}
    user_turn = {"role": "user", "content": json.dumps(request, ensure_ascii=False)}
    if system_prompt is None:
        messages = [user_turn]
    else:
        messages = [{"role": "system", "content": system_prompt}, user_turn]

    # The chat template loaded with pe_tokenizer renders the conversation to plain text.
    chat_text = self.pe_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,  # "Output:" is already in the user block
    )
    inputs = self.pe_tokenizer(chat_text, return_tensors="pt").to(device)

    # Sampling is switched on unless both temperature and top_p are exactly 1.0.
    use_sampling = temperature != 1.0 or top_p != 1.0
    output_ids = self.pe.generate(
        **inputs,
        max_new_tokens=self.pe_tokenizer.model_max_length,
        do_sample=use_sampling,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=self.pe_tokenizer.pad_token_id,
        eos_token_id=self.pe_tokenizer.eos_token_id,
    )

    # Drop the echoed input tokens so only the newly generated continuation is decoded.
    input_length = inputs["input_ids"].shape[1]
    rewritten = self.pe_tokenizer.decode(output_ids[0][input_length:], skip_special_tokens=True)
    return rewritten.strip()
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: torch.device,
|
||||
num_images_per_prompt: int = 1,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Encode text prompts to embeddings."""
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
text_hiddens = []
|
||||
|
||||
for p in prompt:
|
||||
ids = self.tokenizer(
|
||||
p,
|
||||
add_special_tokens=True,
|
||||
truncation=True,
|
||||
padding=False,
|
||||
)["input_ids"]
|
||||
|
||||
if len(ids) == 0:
|
||||
if self.tokenizer.bos_token_id is not None:
|
||||
ids = [self.tokenizer.bos_token_id]
|
||||
else:
|
||||
ids = [0]
|
||||
|
||||
input_ids = torch.tensor([ids], device=device)
|
||||
with torch.no_grad():
|
||||
outputs = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# Use second to last hidden state (matches training)
|
||||
hidden = outputs.hidden_states[-2][0] # [T, H]
|
||||
|
||||
# Repeat for num_images_per_prompt
|
||||
for _ in range(num_images_per_prompt):
|
||||
text_hiddens.append(hidden)
|
||||
|
||||
return text_hiddens
|
||||
|
||||
@staticmethod
|
||||
def _patchify_latents(latents: torch.Tensor) -> torch.Tensor:
|
||||
"""2x2 patchify: [B, 32, H, W] -> [B, 128, H/2, W/2]"""
|
||||
b, c, h, w = latents.shape
|
||||
latents = latents.view(b, c, h // 2, 2, w // 2, 2)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
||||
return latents.reshape(b, c * 4, h // 2, w // 2)
|
||||
|
||||
@staticmethod
|
||||
def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor:
|
||||
"""Reverse patchify: [B, 128, H/2, W/2] -> [B, 32, H, W]"""
|
||||
b, c, h, w = latents.shape
|
||||
latents = latents.reshape(b, c // 4, 2, 2, h, w)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||
return latents.reshape(b, c // 4, h * 2, w * 2)
|
||||
|
||||
@staticmethod
|
||||
def _pad_text(text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype, text_in_dim: int):
|
||||
B = len(text_hiddens)
|
||||
if B == 0:
|
||||
return torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), torch.zeros(
|
||||
(0,), device=device, dtype=torch.long
|
||||
)
|
||||
normalized = [
|
||||
th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens
|
||||
]
|
||||
lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long)
|
||||
Tmax = int(lens.max().item())
|
||||
text_bth = torch.zeros((B, Tmax, text_in_dim), device=device, dtype=dtype)
|
||||
for i, t in enumerate(normalized):
|
||||
text_bth[i, : t.shape[0], :] = t
|
||||
return text_bth, lens
|
||||
|
||||
@torch.no_grad()
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = "",
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 4.0,
    num_images_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: list[torch.FloatTensor] | None = None,
    negative_prompt_embeds: list[torch.FloatTensor] | None = None,
    output_type: str = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    use_pe: bool = True,  # prompt rewriting with the PE model is enabled by default
):
    """
    Generate images from text prompts.

    Args:
        prompt: Text prompt(s). Mutually exclusive with `prompt_embeds`.
        negative_prompt: Negative prompt(s) for CFG. Default is "". A single string is broadcast to every
            prompt; a list must have one entry per prompt.
        height: Image height in pixels (must be divisible by `self.vae_scale_factor`). Default: 1024.
        width: Image width in pixels (must be divisible by `self.vae_scale_factor`). Default: 1024.
        num_inference_steps: Number of denoising steps.
        guidance_scale: CFG scale. Default: 4.0.
        num_images_per_prompt: Number of images per prompt.
        generator: Random generator for reproducibility.
        latents: Pre-generated latents (optional). Used as-is instead of sampling noise.
        prompt_embeds: Pre-computed text embeddings for positive prompts (optional), one tensor per sample.
            If provided, `encode_prompt` is skipped for positive prompts.
        negative_prompt_embeds: Pre-computed text embeddings for negative prompts (optional).
            If provided, `encode_prompt` is skipped for negative prompts.
        output_type: "pil" returns PIL images, "latent" returns the raw latents tensor; any other value
            returns the decoded images as a float numpy array in [0, 1] with layout [B, H, W, C].
        return_dict: Whether to return a dataclass instead of a tuple (ignored for `output_type="latent"`).
        callback_on_step_end: Optional callback invoked at the end of each denoising step.
            Called as `callback_on_step_end(pipeline, step, timestep, callback_kwargs)` where `callback_kwargs`
            contains the tensors listed in `callback_on_step_end_tensor_inputs`. The callback may return a dict
            to override `latents` for subsequent steps, or `None` to leave them unchanged.
        callback_on_step_end_tensor_inputs: Names of local tensors passed into the callback kwargs
            (default: `["latents"]`).
        use_pe: Whether to use the PE model to enhance prompts before generation. Only takes effect when
            `prompt` is given and both `self.pe` and `self.pe_tokenizer` are available.

    Returns:
        :class:`ErnieImagePipelineOutput` with `images` and `revised_prompts`; a tuple `(images,)` when
        `return_dict=False`; or the latents tensor itself when `output_type="latent"`.

    Raises:
        ValueError: If neither or both of `prompt` / `prompt_embeds` are given, if `height` / `width` are
            not divisible by `self.vae_scale_factor`, or if `negative_prompt` has the wrong length.
    """
    device = self._execution_device
    dtype = self.transformer.dtype

    self._guidance_scale = guidance_scale

    # Validate prompt / prompt_embeds
    if prompt is None and prompt_embeds is None:
        raise ValueError("Must provide either `prompt` or `prompt_embeds`.")
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot provide both `prompt` and `prompt_embeds` at the same time.")

    # Validate dimensions
    if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
        raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}")

    # Handle prompts
    if isinstance(prompt, str):
        prompt = [prompt]

    # [Phase 1] PE: enhance prompts
    revised_prompts: Optional[List[str]] = None
    if prompt is not None and use_pe and self.pe is not None and self.pe_tokenizer is not None:
        prompt = [self._enhance_prompt_with_pe(p, device, width=width, height=height) for p in prompt]
        revised_prompts = list(prompt)

    batch_size = len(prompt) if prompt is not None else len(prompt_embeds)
    total_batch_size = batch_size * num_images_per_prompt

    # Handle negative prompt
    if negative_prompt is None:
        negative_prompt = ""
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * batch_size
    if len(negative_prompt) != batch_size:
        raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})")

    # [Phase 2] Text encoding
    if prompt_embeds is not None:
        # NOTE(review): user-supplied embeddings are used as-is and are not repeated for
        # `num_images_per_prompt`, while the latent batch is -- confirm the intended contract.
        text_hiddens = prompt_embeds
    else:
        text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt)

    # CFG with negative prompt
    if self.do_classifier_free_guidance:
        if negative_prompt_embeds is not None:
            uncond_text_hiddens = negative_prompt_embeds
        else:
            uncond_text_hiddens = self.encode_prompt(negative_prompt, device, num_images_per_prompt)

    # Latent dimensions
    latent_h = height // self.vae_scale_factor
    latent_w = width // self.vae_scale_factor
    latent_channels = self.transformer.config.in_channels  # After patchify

    # Initialize latents
    if latents is None:
        latents = randn_tensor(
            (total_batch_size, latent_channels, latent_h, latent_w),
            generator=generator,
            device=device,
            dtype=dtype,
        )

    # Setup scheduler: linearly spaced sigmas from 1.0 down, dropping the final 0.0 entry
    sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1)
    self.scheduler.set_timesteps(sigmas=sigmas[:-1], device=device)

    # Denoising loop: with CFG the batch is ordered [unconditional, conditional]
    if self.do_classifier_free_guidance:
        cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens)
    else:
        cfg_text_hiddens = text_hiddens
    text_bth, text_lens = self._pad_text(
        text_hiddens=cfg_text_hiddens, device=device, dtype=dtype, text_in_dim=self.transformer.config.text_in_dim
    )

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(self.scheduler.timesteps):
            if self.do_classifier_free_guidance:
                latent_model_input = torch.cat([latents, latents], dim=0)
                t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype)
            else:
                latent_model_input = latents
                t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype)

            # Model prediction
            pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=t_batch,
                text_bth=text_bth,
                text_lens=text_lens,
                return_dict=False,
            )[0]

            # Apply CFG
            if self.do_classifier_free_guidance:
                pred_uncond, pred_cond = pred.chunk(2, dim=0)
                pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # Scheduler step
            latents = self.scheduler.step(pred, t, latents).prev_sample

            # Callback
            if callback_on_step_end is not None:
                # Use an explicit loop: before Python 3.12 a comprehension runs in its own scope,
                # so `locals()` inside it would not contain this method's variables.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                # A callback that only observes may return None; keep the current latents then.
                if callback_outputs is not None:
                    latents = callback_outputs.pop("latents", latents)

            progress_bar.update()

    if output_type == "latent":
        # Release offload hooks on this early-exit path too, as the decode path does.
        self.maybe_free_model_hooks()
        return latents

    # Decode latents to images
    # Unnormalize latents using VAE's BN stats
    bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device)
    bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device)
    latents = latents * bn_std + bn_mean

    # Unpatchify
    latents = self._unpatchify_latents(latents)

    # Decode
    images = self.vae.decode(latents, return_dict=False)[0]

    # Post-process: map [-1, 1] to [0, 1] and convert to channels-last numpy
    images = (images.clamp(-1, 1) + 1) / 2
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()

    if output_type == "pil":
        images = [Image.fromarray((img * 255).astype("uint8")) for img in images]

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (images,)

    return ErnieImagePipelineOutput(images=images, revised_prompts=revised_prompts)
|
||||
36
src/diffusers/pipelines/ernie_image/pipeline_output.py
Normal file
36
src/diffusers/pipelines/ernie_image/pipeline_output.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
class ErnieImagePipelineOutput(BaseOutput):
    """
    Output class for ERNIE-Image pipelines.

    Args:
        images (`List[PIL.Image.Image]`):
            List of generated images.
        revised_prompts (`List[str]`, *optional*, defaults to `None`):
            List of PE-revised prompts. `None` when PE is disabled or unavailable.
    """

    images: List[PIL.Image.Image]
    # Defaulted so the field is optional at construction time, as the docstring states.
    revised_prompts: Optional[List[str]] = None
|
||||
@@ -611,7 +611,7 @@ class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
|
||||
tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v
|
||||
"""
|
||||
|
||||
batch, channels, frames, height, width = latents.shape
|
||||
batch, channels, frames, latent_height, latent_width = latents.shape
|
||||
|
||||
image_latents = self._get_image_latents(
|
||||
vae=self.vae,
|
||||
@@ -626,7 +626,7 @@ class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
|
||||
latent_condition[:, :, 1:, :, :] = 0
|
||||
latent_condition = latent_condition.to(device=device, dtype=dtype)
|
||||
|
||||
latent_mask = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device)
|
||||
latent_mask = torch.zeros(batch, 1, frames, latent_height, latent_width, dtype=dtype, device=device)
|
||||
latent_mask[:, :, 0, :, :] = 1.0
|
||||
|
||||
return latent_condition, latent_mask
|
||||
|
||||
@@ -1110,6 +1110,21 @@ class EasyAnimateTransformer3DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ErnieImageTransformer2DModel(metaclass=DummyObject):
    # Stand-in class: every entry point forwards to `requires_backends` with the backends listed below.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)
|
||||
|
||||
|
||||
class Flux2Transformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -242,6 +242,36 @@ class HeliosPyramidModularPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXAutoBlocks(metaclass=DummyObject):
    # Stand-in class: every entry point forwards to `requires_backends` with the backends listed below.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)
|
||||
|
||||
|
||||
class LTXModularPipeline(metaclass=DummyObject):
    # Stand-in class: every entry point forwards to `requires_backends` with the backends listed below.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)
|
||||
|
||||
|
||||
class QwenImageAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1202,6 +1232,21 @@ class EasyAnimatePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ErnieImagePipeline(metaclass=DummyObject):
    # Stand-in class: every entry point forwards to `requires_backends` with the backends listed below.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
132
tests/models/transformers/test_models_transformer_ernie_image.py
Normal file
132
tests/models/transformers/test_models_transformer_ernie_image.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import ErnieImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
# Ernie-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations,
# so enable_full_determinism() (which sets it to True) cannot be used. The remaining
# determinism-related switches are configured by hand instead.
os.environ.update(
    {
        "CUDA_LAUNCH_BLOCKING": "1",
        "CUBLAS_WORKSPACE_CONFIG": ":16:8",
    }
)
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
_cuda_backend = getattr(torch.backends, "cuda", None)
if _cuda_backend is not None:
    _cuda_backend.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class ErnieImageTransformerTesterConfig(BaseModelTesterConfig):
    """Shared configuration (model class, tiny init kwargs, dummy inputs) for the ERNIE-Image transformer tests."""

    @property
    def model_class(self):
        return ErnieImageTransformer2DModel

    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def output_shape(self) -> tuple:
        return (16, 16, 16)

    @property
    def input_shape(self) -> tuple:
        return (16, 16, 16)

    @property
    def model_split_percents(self) -> list:
        # Overridden because the transformer under test is small.
        return [0.9] * 3

    @property
    def generator(self):
        # A new CPU generator seeded with 0 is built on every access.
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict:
        """Return constructor kwargs for a minimal single-layer model."""
        return dict(
            hidden_size=16,
            num_attention_heads=1,
            num_layers=1,
            ffn_hidden_size=16,
            in_channels=16,
            out_channels=16,
            patch_size=1,
            text_in_dim=16,
            rope_theta=256,
            rope_axes_dim=(8, 4, 4),
            eps=1e-6,
            qk_layernorm=True,
        )

    def get_dummy_inputs(self, height: int = 16, width: int = 16, batch_size: int = 1) -> dict:
        """Build forward-pass inputs whose sizes match `get_init_dict`."""
        num_channels = 16  # in_channels
        sequence_length = 16
        text_in_dim = 16  # text_in_dim

        image_shape = (batch_size, num_channels, height, width)
        text_shape = (batch_size, sequence_length, text_in_dim)

        # `self.generator` is read separately for each tensor, as each access yields a fresh seeded generator.
        hidden_states = randn_tensor(image_shape, generator=self.generator, device=torch_device)
        timestep = torch.tensor([1.0] * batch_size, device=torch_device)
        text_bth = randn_tensor(text_shape, generator=self.generator, device=torch_device)
        text_lens = torch.tensor([sequence_length] * batch_size, device=torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "text_bth": text_bth,
            "text_lens": text_lens,
        }
|
||||
|
||||
|
||||
class TestErnieImageTransformer(ErnieImageTransformerTesterConfig, ModelTesterMixin):
    """Runs the tests inherited from `ModelTesterMixin` against the ERNIE-Image tester config."""
|
||||
|
||||
|
||||
class TestErnieImageTransformerTraining(ErnieImageTransformerTesterConfig, TrainingTesterMixin):
    """Training-related tests for the ERNIE-Image transformer."""

    def test_gradient_checkpointing_is_applied(self):
        # Delegate to the mixin, naming the class expected to have gradient checkpointing applied.
        super().test_gradient_checkpointing_is_applied(expected_set={"ErnieImageTransformer2DModel"})
|
||||
|
||||
|
||||
class TestErnieImageTransformerCompile(ErnieImageTransformerTesterConfig, TorchCompileTesterMixin):
    """torch.compile tests for the ERNIE-Image transformer; the fullgraph-dependent ones are skipped."""

    @property
    def different_shapes_for_compilation(self):
        # (height, width) pairs -- presumably consumed by the mixin's shape-variation test; TODO confirm.
        shapes = [(4, 4), (4, 8), (8, 8)]
        return shapes

    @pytest.mark.skip(
        reason=(
            "The repeated block in this model is ErnieImageSharedAdaLNBlock. As a consequence of this, "
            "the inputs recorded for the block would vary during compilation and full compilation with "
            "fullgraph=True would trigger recompilation."
        )
    )
    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()

    @pytest.mark.skip(reason="Fullgraph AoT is broken.")
    def test_compile_works_with_aot(self, tmp_path):
        super().test_compile_works_with_aot(tmp_path)

    @pytest.mark.skip(reason="Fullgraph is broken.")
    def test_compile_on_different_shapes(self):
        super().test_compile_on_different_shapes()
|
||||
0
tests/modular_pipelines/ltx/__init__.py
Normal file
0
tests/modular_pipelines/ltx/__init__.py
Normal file
72
tests/modular_pipelines/ltx/test_modular_pipeline_ltx.py
Normal file
72
tests/modular_pipelines/ltx/test_modular_pipeline_ltx.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from diffusers.modular_pipelines import LTXAutoBlocks, LTXModularPipeline
|
||||
|
||||
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
|
||||
|
||||
# Expected (block path, block class name) sequence for each LTX workflow.
LTX_WORKFLOWS = dict(
    text2video=[
        ("text_encoder", "LTXTextEncoderStep"),
        ("denoise.input", "LTXTextInputStep"),
        ("denoise.set_timesteps", "LTXSetTimestepsStep"),
        ("denoise.prepare_latents", "LTXPrepareLatentsStep"),
        ("denoise.denoise", "LTXDenoiseStep"),
        ("decode", "LTXVaeDecoderStep"),
    ],
    image2video=[
        ("text_encoder", "LTXTextEncoderStep"),
        ("vae_encoder", "LTXVaeEncoderStep"),
        ("denoise.input", "LTXTextInputStep"),
        ("denoise.set_timesteps", "LTXSetTimestepsStep"),
        ("denoise.prepare_latents", "LTXPrepareLatentsStep"),
        ("denoise.prepare_i2v_latents", "LTXImage2VideoPrepareLatentsStep"),
        ("denoise.denoise", "LTXImage2VideoDenoiseStep"),
        ("decode", "LTXVaeDecoderStep"),
    ],
)
|
||||
|
||||
|
||||
class TestLTXModularPipelineFast(ModularPipelineTesterMixin):
    """Fast tests for the LTX modular pipeline, run against a tiny pretrained checkpoint."""

    pipeline_class = LTXModularPipeline
    pipeline_blocks_class = LTXAutoBlocks
    pretrained_model_name_or_path = "akshan-main/tiny-ltx-modular-pipe"

    params = frozenset({"prompt", "height", "width", "num_frames"})
    batch_params = frozenset({"prompt"})
    optional_params = frozenset({"num_inference_steps", "num_videos_per_prompt", "latents"})
    expected_workflow_blocks = LTX_WORKFLOWS
    output_name = "videos"

    def get_dummy_inputs(self, seed=0):
        """Return a small, seeded set of call kwargs for the pipeline."""
        return {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": self.get_generator(seed),
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

    @pytest.mark.skip(reason="num_videos_per_prompt")
    def test_num_images_per_prompt(self):
        pass
|
||||
@@ -13,16 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
from importlib import import_module
|
||||
|
||||
import pytest
|
||||
|
||||
class DependencyTester(unittest.TestCase):
|
||||
|
||||
class TestDependencies:
|
||||
def test_diffusers_import(self):
|
||||
try:
|
||||
import diffusers # noqa: F401
|
||||
except ImportError:
|
||||
assert False
|
||||
import diffusers # noqa: F401
|
||||
|
||||
def test_backend_registration(self):
|
||||
import diffusers
|
||||
@@ -52,3 +50,36 @@ class DependencyTester(unittest.TestCase):
|
||||
if hasattr(diffusers.pipelines, cls_name):
|
||||
pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3])
|
||||
_ = import_module(pipeline_folder_module, str(cls_name))
|
||||
|
||||
def test_pipeline_module_imports(self):
    """Import every pipeline submodule whose dependencies are satisfied,
    to catch unguarded optional-dep imports (e.g., torchvision).

    Uses inspect.getmembers to discover classes that the lazy loader can
    actually resolve (same self-filtering as test_pipeline_imports), then
    imports the full module path instead of truncating to the folder level.
    Each module is imported once, so a broken module is reported once
    rather than once per class it defines.
    """
    import diffusers
    import diffusers.pipelines

    # dict.fromkeys drops duplicate module paths while keeping first-seen order.
    module_paths = dict.fromkeys(
        cls.__module__
        for cls_name, cls in inspect.getmembers(diffusers, inspect.isclass)
        if hasattr(diffusers.pipelines, cls_name) and "dummy_" not in cls.__module__
    )

    failures = []
    for full_module_path in module_paths:
        try:
            import_module(full_module_path)
        except ImportError as e:
            failures.append(f"{full_module_path}: {e}")
        except Exception:
            # Non-import errors (e.g., missing config) are deliberately ignored; we only
            # care about unguarded import statements.
            pass

    if failures:
        pytest.fail("Unguarded optional-dependency imports found:\n" + "\n".join(failures))
|
||||
|
||||
Reference in New Issue
Block a user