Compare commits

..

6 Commits

Author SHA1 Message Date
dg845
7298f5be93 Update LTX-2 Docs to Cover LTX-2.3 Models (#13337)
* Update LTX-2 docs to cover multimodal guidance and prompt enhancement

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply reviewer feedback

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-03-26 17:51:29 -07:00
Sayak Paul
b757035df6 fix claude workflow to include id-token with write. (#13338) 2026-03-26 15:39:10 +05:30
kaixuanliu
41e1003316 avoid hardcode device in flux-control example (#13336)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2026-03-26 12:40:53 +05:30
Sayak Paul
85ffcf1db2 [tests] Tests for conditional pipeline blocks (#13247)
* implement test suite for conditional blocks.

* remove

* another fix.

* Revert "another fix."

This reverts commit ab07b603ab.
2026-03-26 08:48:16 +05:30
Steven Liu
cbf4d9a3c3 [docs] kernels (#13139)
* kernels

* feedback
2026-03-25 09:31:54 -07:00
Sayak Paul
426daabad9 [ci] claude in ci. (#13297)
* claude in ci.

* review feedback.
2026-03-25 21:30:06 +05:30
14 changed files with 1094 additions and 535 deletions

11
.ai/review-rules.md Normal file
View File

@@ -0,0 +1,11 @@
# PR Review Rules
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
Before reviewing, read and apply the guidelines in:
- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
## Common mistakes (add new rules below this line)

39
.github/workflows/claude_review.yml vendored Normal file
View File

@@ -0,0 +1,39 @@
name: Claude PR Review
on:
issue_comment:
types: [created]
pull_request_review_comment:
types: [created]
permissions:
contents: write
pull-requests: write
issues: read
id-token: write
jobs:
claude-review:
if: |
(
github.event_name == 'issue_comment' &&
github.event.issue.pull_request &&
github.event.issue.state == 'open' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'MEMBER' ||
github.event.comment.author_association == 'OWNER' ||
github.event.comment.author_association == 'COLLABORATOR')
) || (
github.event_name == 'pull_request_review_comment' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'MEMBER' ||
github.event.comment.author_association == 'OWNER' ||
github.event.comment.author_association == 'COLLABORATOR')
)
runs-on: ubuntu-latest
steps:
- uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: |
--append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). Only review changes under src/diffusers/. Do NOT commit changes unless the comment explicitly asks you to using the phrase 'commit this'."

View File

@@ -18,7 +18,7 @@
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
[LTX-2](https://hf.co/papers/2601.03233) is a DiT-based foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
@@ -293,6 +293,7 @@ import torch
from diffusers import LTX2ConditionPipeline
from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
from diffusers.utils import load_image, load_video
device = "cuda"
@@ -315,19 +316,6 @@ prompt = (
"landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the "
"solitude and beauty of a winter drive through a mountainous region."
)
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
cond_video = load_video(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
@@ -343,7 +331,7 @@ frame_rate = 24.0
video, audio = pipe(
conditions=conditions,
prompt=prompt,
negative_prompt=negative_prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
@@ -366,6 +354,154 @@ encode_video(
Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static.
## Multimodal Guidance
LTX-2.X pipelines support multimodal guidance. It is composed of three terms, all using a CFG-style update rule:
1. Classifier-Free Guidance (CFG): standard [CFG](https://huggingface.co/papers/2207.12598) where the perturbed ("weaker") output is generated using the negative prompt.
2. Spatio-Temporal Guidance (STG): [STG](https://huggingface.co/papers/2411.18664) moves away from a perturbed output created from short-cutting self-attention operations and substitutes in the attention values instead. The idea is that this creates sharper videos and better spatiotemporal consistency.
3. Modality Isolation Guidance: moves away from a perturbed output created from disabling cross-modality (audio-to-video and video-to-audio) cross attention. This guidance is more specific to [LTX-2.X](https://huggingface.co/papers/2601.03233) models, with the idea that this produces better consistency between the generated audio and video.
These are controlled by the `guidance_scale`, `stg_scale`, and `modality_scale` arguments and can be set separately for video and audio. Additionally, for STG the transformer block indices where self-attention is skipped needs to be specified via the `spatio_temporal_guidance_blocks` argument. The LTX-2.X pipelines also support [guidance rescaling](https://huggingface.co/papers/2305.08891) to help reduce over-exposure, which can be a problem when the guidance scales are set to high values.
```py
import torch
from diffusers import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
from diffusers.utils import load_image
device = "cuda"
width = 768
height = 512
random_seed = 42
frame_rate = 24.0
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "dg845/LTX-2.3-Diffusers"
pipe = LTX2ImageToVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload(device=device)
pipe.vae.enable_tiling()
prompt = (
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
"breath-taking, movie-like shot."
)
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
)
video, audio = pipe(
image=image,
prompt=prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=30,
guidance_scale=3.0, # Recommended LTX-2.3 guidance parameters
stg_scale=1.0, # Note that 0.0 (not 1.0) means that STG is disabled (all other guidance is disabled at 1.0)
modality_scale=3.0,
guidance_rescale=0.7,
audio_guidance_scale=7.0, # Note that a higher CFG guidance scale is recommended for audio
audio_stg_scale=1.0,
audio_modality_scale=3.0,
audio_guidance_rescale=0.7,
spatio_temporal_guidance_blocks=[28],
use_cross_timestep=True,
generator=generator,
output_type="np",
return_dict=False,
)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_3_i2v_stage_1.mp4",
)
```
## Prompt Enhancement
The LTX-2.X models are sensitive to prompting style. Refer to the [official prompting guide](https://ltx.io/model/model-blog/prompting-guide-for-ltx-2) for recommendations on how to write a good prompt. Using prompt enhancement, where the supplied prompts are enhanced using the pipeline's text encoder (by default a [Gemma 3](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized) model) given a system prompt, can also improve sample quality. The optional `processor` pipeline component needs to be present to use prompt enhancement. Enable prompt enhancement by supplying a `system_prompt` argument:
```py
import torch
from transformers import Gemma3Processor
from diffusers import LTX2Pipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT, T2V_DEFAULT_SYSTEM_PROMPT
device = "cuda"
width = 768
height = 512
random_seed = 42
frame_rate = 24.0
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "dg845/LTX-2.3-Diffusers"
pipe = LTX2Pipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()
if getattr(pipe, "processor", None) is None:
processor = Gemma3Processor.from_pretrained("google/gemma-3-12b-it-qat-q4_0-unquantized")
pipe.processor = processor
prompt = (
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
"breath-taking, movie-like shot."
)
video, audio = pipe(
prompt=prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=30,
guidance_scale=3.0,
stg_scale=1.0,
modality_scale=3.0,
guidance_rescale=0.7,
audio_guidance_scale=7.0,
audio_stg_scale=1.0,
audio_modality_scale=3.0,
audio_guidance_rescale=0.7,
spatio_temporal_guidance_blocks=[28],
use_cross_timestep=True,
system_prompt=T2V_DEFAULT_SYSTEM_PROMPT,
generator=generator,
output_type="np",
return_dict=False,
)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_3_t2v_stage_1.mp4",
)
```
## LTX2Pipeline
[[autodoc]] LTX2Pipeline

View File

@@ -248,6 +248,24 @@ Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/be
The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.
## Kernels
[Kernels](https://huggingface.co/docs/kernels/index) is a library for building, distributing, and loading optimized compute kernels on the [Hub](https://huggingface.co/kernels-community). It supports [attention](./attention_backends#set_attention_backend) kernels and custom CUDA kernels for operations like RMSNorm, GEGLU, RoPE, and AdaLN.
The [Diffusers Pipeline Integration](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/references/diffusers-integration.md) guide shows how to integrate a kernel with the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill. This skill enables an agent, like Claude or Codex, to write custom kernels targeted towards a specific model and your hardware.
> [!TIP]
> Install the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill to teach an agent how to write a kernel. The [Custom kernels for all from Codex and Claude](https://huggingface.co/blog/custom-cuda-kernels-agent-skills) blog post covers this in more detail.
For example, a custom RMSNorm kernel (generated by the `add cuda-kernels` skill) with [torch.compile](#torchcompile) speeds up LTX-Video generation 1.43x on an H100.
<iframe
src="https://huggingface.co/datasets/docs-benchmarks/kernel-ltx-video/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>
## Dynamic quantization
[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.

View File

@@ -1105,7 +1105,7 @@ def main(args):
# text encoding.
captions = batch["captions"]
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
captions, prompt_2=None

View File

@@ -1251,7 +1251,7 @@ def main(args):
# text encoding.
captions = batch["captions"]
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
captions, prompt_2=None

View File

@@ -1,6 +1,155 @@
# Copyright 2026 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Pre-trained sigma values for distilled model are taken from
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
# Default negative prompt from
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py#L131-L143
DEFAULT_NEGATIVE_PROMPT = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
# System prompts for prompt enhancement
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt#L1
# Disable line-too-long rule in ruff to keep the prompts exactly the same (e.g. in terms of newlines)
# Supported in ruff>=0.15.0
# ruff: disable[E501]
T2V_DEFAULT_SYSTEM_PROMPT = """
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed
video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
#### Guidelines
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions,
actions, camera movement, audio).
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
movements.
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested).
Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g.,
"ambient sound is present").
- Speech (only when requested):
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with
voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
- Specify language if not English and accent if relevant.
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if
unspecified. Omit if unclear.
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
#### Important notes:
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is
requested.
- Camera motion: DO NOT invent camera motion unless requested by the user.
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological
scene description.
- Format: DO NOT start your response with special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits
or introduce new elements. Add/enhance audio descriptions if missing.
#### Output Format (Strict):
- Single continuous paragraph in natural language (English).
- NO titles, headings, prefaces, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video
generation.
#### Example Input: "A woman at a coffee shop talking on the phone" Output: Style: realistic with cinematic lighting.
In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the
window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone
to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle
clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a
soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd
love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her
chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully,
lowering the phone.
"""
# ruff: enable[E501]
# ruff: disable[E501]
I2V_DEFAULT_SYSTEM_PROMPT = """
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and
user Raw Input Prompt, generate a prompt to guide video generation from that image.
#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in
conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image
to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause
scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio
intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when
requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The
tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation
mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should
include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just
saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room
underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion
only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style
(optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio
descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
#### Example output: Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a
cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the
counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine
hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
"""
# ruff: enable[E501]

View File

@@ -12,53 +12,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanVideo15Transformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class HunyuanVideo15TransformerTesterConfig(BaseModelTesterConfig):
class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideo15Transformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.99, 0.99, 0.99]
text_embed_dim = 16
text_embed_2_dim = 8
image_embed_dim = 12
@property
def model_class(self):
return HunyuanVideo15Transformer3DModel
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 8
width = 8
sequence_length = 6
sequence_length_2 = 4
image_sequence_length = 3
@property
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device)
encoder_hidden_states_2 = torch.randn(
(batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device
)
encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device)
encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device)
# All zeros for inducing T2V path in the model.
image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device)
@property
def model_split_percents(self) -> list:
return [0.99, 0.99, 0.99]
@property
def output_shape(self) -> tuple:
return (4, 1, 8, 8)
@property
def input_shape(self) -> tuple:
return (4, 1, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"encoder_hidden_states_2": encoder_hidden_states_2,
"encoder_attention_mask_2": encoder_attention_mask_2,
"image_embeds": image_embeds,
}
@property
def input_shape(self):
return (4, 1, 8, 8)
@property
def output_shape(self):
return (4, 1, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -75,40 +93,9 @@ class HunyuanVideo15TransformerTesterConfig(BaseModelTesterConfig):
"target_size": 16,
"task_type": "t2v",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 8
width = 8
sequence_length = 6
sequence_length_2 = 4
image_sequence_length = 3
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, self.text_embed_dim), generator=self.generator, device=torch_device
),
"encoder_hidden_states_2": randn_tensor(
(batch_size, sequence_length_2, self.text_embed_2_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length), device=torch_device),
"encoder_attention_mask_2": torch.ones((batch_size, sequence_length_2), device=torch_device),
"image_embeds": torch.zeros(
(batch_size, image_sequence_length, self.image_embed_dim), device=torch_device
),
}
class TestHunyuanVideo15Transformer(HunyuanVideo15TransformerTesterConfig, ModelTesterMixin):
pass
class TestHunyuanVideo15TransformerTraining(HunyuanVideo15TransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideo15Transformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -13,100 +13,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanDiT2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class HunyuanDiTTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanDiT2DModel
class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanDiT2DModel
main_input_name = "hidden_states"
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-hunyuan-dit-pipe"
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (8, 8, 8)
@property
def input_shape(self) -> tuple:
return (4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"sample_size": 8,
"patch_size": 2,
"in_channels": 4,
"num_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 2,
"cross_attention_dim": 8,
"cross_attention_dim_t5": 8,
"pooled_projection_dim": 4,
"hidden_size": 16,
"text_len": 4,
"text_len_t5": 4,
"activation_fn": "gelu-approximate",
}
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
def dummy_input(self):
batch_size = 2
num_channels = 4
height = width = 8
embedding_dim = 8
sequence_length = 4
sequence_length_t5 = 4
hidden_states = randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_embedding_mask = torch.ones(size=(batch_size, sequence_length)).to(torch_device)
encoder_hidden_states_t5 = randn_tensor(
(batch_size, sequence_length_t5, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length_t5, embedding_dim)).to(torch_device)
text_embedding_mask_t5 = torch.ones(size=(batch_size, sequence_length_t5)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).float().to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,), dtype=encoder_hidden_states.dtype).to(torch_device)
original_size = [1024, 1024]
target_size = [16, 16]
crops_coords_top_left = [0, 0]
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=torch.float32).to(torch_device)
add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=encoder_hidden_states.dtype).to(torch_device)
style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device)
image_rotary_emb = [
torch.ones(size=(1, 8), dtype=torch.float32),
torch.zeros(size=(1, 8), dtype=torch.float32),
torch.ones(size=(1, 8), dtype=encoder_hidden_states.dtype),
torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype),
]
return {
@@ -121,26 +72,42 @@ class HunyuanDiTTesterConfig(BaseModelTesterConfig):
"image_rotary_emb": image_rotary_emb,
}
@property
def input_shape(self):
return (4, 8, 8)
@property
def output_shape(self):
return (8, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 8,
"patch_size": 2,
"in_channels": 4,
"num_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 2,
"cross_attention_dim": 8,
"cross_attention_dim_t5": 8,
"pooled_projection_dim": 4,
"hidden_size": 16,
"text_len": 4,
"text_len_t5": 4,
"activation_fn": "gelu-approximate",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin):
def test_output(self):
batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0]
super().test_output(expected_output_shape=(batch_size,) + self.output_shape)
super().test_output(
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
def test_set_xformers_attn_processor_for_determinism(self):
pass
class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanDiT2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestHunyuanDiTCompile(HunyuanDiTTesterConfig, TorchCompileTesterMixin):
pass
class TestHunyuanDiTBitsAndBytes(HunyuanDiTTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for HunyuanDiT."""
class TestHunyuanDiTTorchAo(HunyuanDiTTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for HunyuanDiT."""
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
def test_set_attn_processor_for_determinism(self):
pass

View File

@@ -12,59 +12,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
# ======================== HunyuanVideo Text-to-Video ========================
class HunyuanVideoTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanVideoTransformer3DModel
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-random-hunyuanvideo"
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (4, 1, 16, 16)
@property
def input_shape(self) -> tuple:
return (4, 1, 16, 16)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -80,9 +85,30 @@ class HunyuanVideoTransformerTesterConfig(BaseModelTesterConfig):
"rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 8
num_frames = 1
height = 16
width = 16
@@ -90,74 +116,105 @@ class HunyuanVideoTransformerTesterConfig(BaseModelTesterConfig):
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"pooled_projections": randn_tensor(
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(
torch_device, dtype=torch.float32
),
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
}
@property
def input_shape(self):
return (8, 1, 16, 16)
class TestHunyuanVideoTransformer(HunyuanVideoTransformerTesterConfig, ModelTesterMixin):
pass
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 8,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 10,
"num_layers": 1,
"num_single_layers": 1,
"num_refiner_layers": 1,
"patch_size": 1,
"patch_size_t": 1,
"guidance_embeds": True,
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
class TestHunyuanVideoTransformerTraining(HunyuanVideoTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestHunyuanVideoTransformerCompile(HunyuanVideoTransformerTesterConfig, TorchCompileTesterMixin):
pass
class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class TestHunyuanVideoTransformerBitsAndBytes(HunyuanVideoTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for HunyuanVideo Transformer."""
class TestHunyuanVideoTransformerTorchAo(HunyuanVideoTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for HunyuanVideo Transformer."""
# ======================== HunyuanVideo Image-to-Video (Latent Concat) ========================
class HunyuanVideoI2VTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanVideoTransformer3DModel
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def main_input_name(self) -> str:
return "hidden_states"
def dummy_input(self):
batch_size = 1
num_channels = 2 * 4 + 1
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
}
@property
def output_shape(self) -> tuple:
return (4, 1, 16, 16)
@property
def input_shape(self) -> tuple:
def input_shape(self):
return (8, 1, 16, 16)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def output_shape(self):
return (4, 1, 16, 16)
def get_init_dict(self) -> dict:
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 2 * 4 + 1,
"out_channels": 4,
"num_attention_heads": 2,
@@ -173,9 +230,33 @@ class HunyuanVideoI2VTransformerTesterConfig(BaseModelTesterConfig):
"rope_axes_dim": (2, 4, 4),
"image_condition_type": "latent_concat",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 2 * 4 + 1
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 2
num_frames = 1
height = 16
width = 16
@@ -183,58 +264,32 @@ class HunyuanVideoI2VTransformerTesterConfig(BaseModelTesterConfig):
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"pooled_projections": randn_tensor(
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
}
class TestHunyuanVideoI2VTransformer(HunyuanVideoI2VTransformerTesterConfig, ModelTesterMixin):
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
class TestHunyuanVideoI2VTransformerCompile(HunyuanVideoI2VTransformerTesterConfig, TorchCompileTesterMixin):
pass
# ======================== HunyuanVideo Token Replace Image-to-Video ========================
class HunyuanVideoTokenReplaceTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanVideoTransformer3DModel
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (4, 1, 16, 16)
@property
def input_shape(self) -> tuple:
def input_shape(self):
return (8, 1, 16, 16)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def output_shape(self):
return (4, 1, 16, 16)
def get_init_dict(self) -> dict:
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 2,
"out_channels": 4,
"num_attention_heads": 2,
@@ -250,42 +305,19 @@ class HunyuanVideoTokenReplaceTransformerTesterConfig(BaseModelTesterConfig):
"rope_axes_dim": (2, 4, 4),
"image_condition_type": "token_replace",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 2
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"pooled_projections": randn_tensor(
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(
torch_device, dtype=torch.float32
),
}
class TestHunyuanVideoTokenReplaceTransformer(HunyuanVideoTokenReplaceTransformerTesterConfig, ModelTesterMixin):
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestHunyuanVideoTokenReplaceTransformerCompile(
HunyuanVideoTokenReplaceTransformerTesterConfig, TorchCompileTesterMixin
):
pass
class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()

View File

@@ -12,49 +12,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanVideoFramepackTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class HunyuanVideoFramepackTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanVideoFramepackTransformer3DModel
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoFramepackTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.5, 0.7, 0.9]
@property
def main_input_name(self) -> str:
return "hidden_states"
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 3
height = 4
width = 4
text_encoder_embedding_dim = 16
image_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
@property
def model_split_percents(self) -> list:
return [0.5, 0.7, 0.9]
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device)
indices_latents = torch.ones((3,)).to(torch_device)
latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device)
latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device)
latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to(
torch_device
)
indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def output_shape(self) -> tuple:
return (4, 3, 4, 4)
@property
def input_shape(self) -> tuple:
return (4, 3, 4, 4)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
"image_embeds": image_embeds,
"indices_latents": indices_latents,
"latents_clean": latents_clean,
"indices_latents_clean": indices_latents_clean,
"latents_history_2x": latents_history_2x,
"indices_latents_history_2x": indices_latents_history_2x,
"latents_history_4x": latents_history_4x,
"indices_latents_history_4x": indices_latents_history_4x,
}
@property
def input_shape(self):
return (4, 3, 4, 4)
@property
def output_shape(self):
return (4, 3, 4, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -73,64 +108,9 @@ class HunyuanVideoFramepackTransformerTesterConfig(BaseModelTesterConfig):
"image_proj_dim": 16,
"has_clean_x_embedder": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 3
height = 4
width = 4
text_encoder_embedding_dim = 16
image_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"pooled_projections": randn_tensor(
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"image_embeds": randn_tensor(
(batch_size, sequence_length, image_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"indices_latents": torch.ones((num_frames,)).to(torch_device),
"latents_clean": randn_tensor(
(batch_size, num_channels, num_frames - 1, height, width),
generator=self.generator,
device=torch_device,
),
"indices_latents_clean": torch.ones((num_frames - 1,)).to(torch_device),
"latents_history_2x": randn_tensor(
(batch_size, num_channels, num_frames - 1, height, width),
generator=self.generator,
device=torch_device,
),
"indices_latents_history_2x": torch.ones((num_frames - 1,)).to(torch_device),
"latents_history_4x": randn_tensor(
(batch_size, num_channels, (num_frames - 1) * 4, height, width),
generator=self.generator,
device=torch_device,
),
"indices_latents_history_4x": torch.ones(((num_frames - 1) * 4,)).to(torch_device),
}
class TestHunyuanVideoFramepackTransformer(HunyuanVideoFramepackTransformerTesterConfig, ModelTesterMixin):
pass
class TestHunyuanVideoFramepackTransformerTraining(HunyuanVideoFramepackTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -0,0 +1,242 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import (
AutoPipelineBlocks,
ConditionalPipelineBlocks,
InputParam,
ModularPipelineBlocks,
)
class TextToImageBlock(ModularPipelineBlocks):
model_name = "text2img"
@property
def inputs(self):
return [InputParam(name="prompt")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "text-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "text2img"
self.set_block_state(state, block_state)
return components, state
class ImageToImageBlock(ModularPipelineBlocks):
model_name = "img2img"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "image-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "img2img"
self.set_block_state(state, block_state)
return components, state
class InpaintBlock(ModularPipelineBlocks):
model_name = "inpaint"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "inpaint workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "inpaint"
self.set_block_state(state, block_state)
return components, state
class ConditionalImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return "Conditional image blocks for testing"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = None # no default; block can be skipped
@property
def description(self):
return "Optional conditional blocks (skippable)"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None
class AutoImageBlocks(AutoPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image", None]
@property
def description(self):
return "Auto image blocks for testing"
class TestConditionalPipelineBlocksSelectBlock:
def test_select_block_with_mask(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="something") == "inpaint"
def test_select_block_with_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(image="something") == "img2img"
def test_select_block_with_mask_and_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
def test_select_block_no_triggers_returns_none(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block() is None
def test_select_block_explicit_none_values(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask=None, image=None) is None
class TestConditionalPipelineBlocksWorkflowSelection:
def test_default_workflow_when_no_triggers(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks()
assert execution is not None
assert isinstance(execution, TextToImageBlock)
def test_mask_trigger_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_image_trigger_selects_img2img(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
def test_mask_and_image_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True, image=True)
assert isinstance(execution, InpaintBlock)
def test_skippable_block_returns_none(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks()
assert execution is None
def test_skippable_block_still_selects_when_triggered(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestAutoPipelineBlocksSelectBlock:
def test_auto_select_mask(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m") == "inpaint"
def test_auto_select_image(self):
blocks = AutoImageBlocks()
assert blocks.select_block(image="i") == "img2img"
def test_auto_select_default(self):
blocks = AutoImageBlocks()
# No trigger -> returns None -> falls back to default (text2img)
assert blocks.select_block() is None
def test_auto_select_priority_order(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
class TestAutoPipelineBlocksWorkflowSelection:
def test_auto_default_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks()
assert isinstance(execution, TextToImageBlock)
def test_auto_mask_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_auto_image_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestConditionalPipelineBlocksStructure:
def test_block_names_accessible(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
def test_sub_block_types(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert isinstance(sub["inpaint"], InpaintBlock)
assert isinstance(sub["img2img"], ImageToImageBlock)
assert isinstance(sub["text2img"], TextToImageBlock)
def test_description(self):
blocks = ConditionalImageBlocks()
assert "Conditional" in blocks.description

View File

@@ -10,11 +10,6 @@ from huggingface_hub import hf_hub_download
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
ConditionalPipelineBlocks,
LoopSequentialPipelineBlocks,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -25,7 +20,6 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
from diffusers.utils import logging
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
@@ -498,117 +492,6 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:

View File

@@ -24,14 +24,18 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
ComponentSpec,
ConditionalPipelineBlocks,
InputParam,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
OutputParam,
PipelineState,
SequentialPipelineBlocks,
WanModularPipeline,
)
from diffusers.utils import logging
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
from ..testing_utils import CaptureLogger, nightly, require_torch, require_torch_accelerator, slow, torch_device
def _create_tiny_model_dir(model_dir):
@@ -463,6 +467,117 @@ class TestModularCustomBlocks:
assert output_prompt.startswith("Modular diffusers + ")
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
@slow
@nightly
@require_torch