Compare commits

..

2 Commits

Author SHA1 Message Date
DN6
52bd90a3b3 update 2026-03-11 13:40:14 +05:30
DN6
b28e9204f7 update 2026-03-11 13:32:16 +05:30
5 changed files with 173 additions and 945 deletions

View File

@@ -2,7 +2,6 @@
line-length = 119
extend-exclude = [
"src/diffusers/pipelines/flux2/system_messages.py",
"src/diffusers/commands/daggr_app.py",
]
[tool.ruff.lint]

View File

@@ -1,726 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import re
import tempfile
import types
import typing
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from ..utils import logging
from . import BaseDiffusersCLICommand
logger = logging.get_logger("diffusers-cli/daggr")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Bare class names treated as "internal": values of these types are routed
# between generated block functions via the tensor store instead of being
# exposed as user-facing Gradio widgets (see _is_internal_type).
INTERNAL_TYPE_NAMES = {
    "Tensor",
    "Generator",
}
# Fully-qualified spellings of the same internal types; also matched as
# substrings of the rendered type hint (covers e.g. Optional[torch.Tensor]).
INTERNAL_TYPE_FULL_NAMES = {
    "torch.Tensor",
    "torch.Generator",
    "torch.dtype",
}
# Well-known numeric pipeline parameters rendered as gr.Slider, with their
# bounds and step sizes (used by _type_hint_to_gradio).
SLIDER_PARAMS = {
    "height": {"minimum": 256, "maximum": 2048, "step": 64},
    "width": {"minimum": 256, "maximum": 2048, "step": 64},
    "num_inference_steps": {"minimum": 1, "maximum": 100, "step": 1},
    "guidance_scale": {"minimum": 0, "maximum": 30, "step": 0.5},
    "strength": {"minimum": 0, "maximum": 1, "step": 0.05},
    "control_guidance_start": {"minimum": 0, "maximum": 1, "step": 0.05},
    "control_guidance_end": {"minimum": 0, "maximum": 1, "step": 0.05},
    "controlnet_conditioning_scale": {"minimum": 0, "maximum": 2, "step": 0.1},
}
# Parameters rendered as gr.Dropdown with a closed choice list.
DROPDOWN_PARAMS = {
    "output_type": {"choices": ["pil", "np", "pt", "latent"], "value": "pil"},
}
# Fallback requirements.txt content used when deploying to a Space and the
# user did not pass --requirements.
DEFAULT_REQUIREMENTS = (
    "torch --index-url https://download.pytorch.org/whl/cu129\n"
    "diffusers\n"
    "transformers\n"
    "accelerate\n"
    "sentencepiece\n"
    "bitsandbytes\n"
    "daggr\n"
    "gradio\n"
)
# ---------------------------------------------------------------------------
# Code templates (string concat so formatters can't mangle newlines)
# ---------------------------------------------------------------------------
_HEADER_TEMPLATE = (
"import os\n"
"{extra_imports}\n"
"import gradio as gr\n"
"from daggr import FnNode, InputNode, Graph\n"
"\n"
"\n"
"_pipeline = None\n"
"_exec_blocks = None\n"
"_tensor_store = {{}}\n"
"\n"
"\n"
"def _get_pipeline():\n"
" global _pipeline, _exec_blocks\n"
" if _pipeline is None:\n"
" from diffusers import ModularPipeline\n"
" import torch\n"
"\n"
' _token = os.environ.get("HF_TOKEN")\n'
' _device = "cuda" if torch.cuda.is_available() else "cpu"\n'
" _pipeline = ModularPipeline.from_pretrained({repo_id}, trust_remote_code=True, token=_token)\n"
" _pipeline.load_components(torch_dtype=torch.bfloat16, device_map=_device)\n"
" _exec_blocks = {workflow_resolve_code}\n"
" return _pipeline, _exec_blocks"
)
_SAVE_IMAGE_TEMPLATE = (
"\n"
"\n"
"def _save_image(val):\n"
" from PIL import Image as PILImage\n"
"\n"
" if isinstance(val, list):\n"
" paths = [_save_image(item) for item in val]\n"
" return paths[0] if paths else None\n"
" if isinstance(val, PILImage.Image):\n"
' f = tempfile.NamedTemporaryFile(suffix=".png", delete=False)\n'
" val.save(f.name)\n"
" return f.name\n"
" return val"
)
_CACHE_DISABLE_TEMPLATE = (
"\n"
"# Disable result caching so every run executes all nodes fresh\n"
"from daggr.state import SessionState\n"
"SessionState.get_latest_result = lambda *a, **kw: None\n"
"SessionState.get_result_by_index = lambda *a, **kw: None\n"
"SessionState.save_result = lambda *a, **kw: None\n"
)
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class BlockInfo:
    """Static description of one executable pipeline sub-block.

    The first five fields are captured directly from the block object in
    `_get_block_info`; the remaining fields start empty and are filled in by
    `_classify_inputs`.
    """

    name: str  # key of the block in exec_blocks.sub_blocks
    class_name: str  # block's class name, for display
    description: str  # block-provided description ("" when absent)
    inputs: list  # input param specs (objects with .name/.type_hint/.default)
    outputs: list  # intermediate output specs; pruned by _filter_outputs
    user_inputs: list = field(default_factory=list)  # inputs that get a Gradio widget
    port_connections: list = field(default_factory=list)  # (input_name, producer_block_name) pairs
    fixed_inputs: list = field(default_factory=list)  # inputs left at their defaults
    sub_block_names: list = field(default_factory=list)  # optional nested sub-block keys to run instead of `name`
# ---------------------------------------------------------------------------
# CLI command
# ---------------------------------------------------------------------------
def daggr_command_factory(args: Namespace):
    """argparse hook: build a DaggrCommand from the parsed CLI namespace.

    Optional flags are read defensively with per-flag fallbacks so the factory
    also works on namespaces that never defined them.
    """
    optional_defaults = {
        "workflow": None,
        "trigger_inputs": None,
        "deploy": None,
        "hardware": "cpu-basic",
        "private": False,
        "requirements": None,
    }
    optional_kwargs = {flag: getattr(args, flag, fallback) for flag, fallback in optional_defaults.items()}
    return DaggrCommand(repo_id=args.repo_id, output=args.output, **optional_kwargs)
class DaggrCommand(BaseDiffusersCLICommand):
    """`diffusers-cli daggr`: generate a daggr graph app from a modular
    pipeline repo, then launch it, save it to a file, or deploy it to a
    HuggingFace Space depending on the flags."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the `daggr` subcommand and its options to the CLI parser."""
        daggr_parser = parser.add_parser("daggr", help="Generate a daggr app from a modular pipeline repo.")
        daggr_parser.add_argument(
            "repo_id",
            type=str,
            help="HuggingFace Hub repo ID containing a modular pipeline.",
        )
        daggr_parser.add_argument(
            "--output",
            type=str,
            default=None,
            help="Save the generated app to a file instead of launching it directly.",
        )
        daggr_parser.add_argument(
            "--workflow",
            type=str,
            default=None,
            help="Named workflow to resolve conditional blocks (e.g. 'text2image', 'image2image').",
        )
        daggr_parser.add_argument(
            "--trigger-inputs",
            nargs="*",
            default=None,
            help="Trigger input names for manual conditional resolution.",
        )
        daggr_parser.add_argument(
            "--deploy",
            type=str,
            default=None,
            metavar="SPACE_NAME",
            help="Deploy the generated app to a HuggingFace Space via daggr deploy.",
        )
        daggr_parser.add_argument(
            "--hardware",
            type=str,
            default="cpu-basic",
            help="Hardware tier for the deployed Space (default: cpu-basic). E.g. a10g-small, a100-large.",
        )
        daggr_parser.add_argument(
            "--private",
            action="store_true",
            default=False,
            help="Make the deployed Space private.",
        )
        daggr_parser.add_argument(
            "--requirements",
            type=str,
            default=None,
            help="Path to a requirements.txt file for the deployed Space.",
        )
        daggr_parser.set_defaults(func=daggr_command_factory)

    def __init__(
        self,
        repo_id: str,
        output: str | None = None,
        workflow: str | None = None,
        trigger_inputs: list | None = None,
        deploy: str | None = None,
        hardware: str = "cpu-basic",
        private: bool = False,
        requirements: str | None = None,
    ):
        self.repo_id = repo_id
        self.output = output
        self.workflow = workflow
        self.trigger_inputs = trigger_inputs
        self.deploy = deploy
        self.hardware = hardware
        self.private = private
        self.requirements = requirements

    def run(self):
        """Load the pipeline, resolve its execution blocks, generate the app
        source, then deploy / save / launch per the CLI flags."""
        from ..modular_pipelines.modular_pipeline import ModularPipeline

        logger.info(f"Loading blocks from {self.repo_id}...")
        pipeline = ModularPipeline.from_pretrained(self.repo_id, trust_remote_code=True)
        blocks = pipeline._blocks
        blocks_class_name = blocks.__class__.__name__
        # Three mutually exclusive ways to resolve conditional blocks:
        # named workflow > explicit trigger inputs > defaults.
        if self.workflow:
            logger.info(f"Resolving workflow: {self.workflow}")
            exec_blocks = blocks.get_workflow(self.workflow)
        elif self.trigger_inputs:
            trigger_kwargs = {name: True for name in self.trigger_inputs}
            logger.info(f"Resolving with trigger inputs: {self.trigger_inputs}")
            exec_blocks = blocks.get_execution_blocks(**trigger_kwargs)
        else:
            logger.info("Resolving default execution blocks...")
            exec_blocks = blocks.get_execution_blocks()
        block_infos = _get_block_info(exec_blocks)
        _filter_outputs(block_infos)
        _classify_inputs(block_infos)
        workflow_label = self.workflow or "default"
        workflow_resolve_code = self._get_workflow_resolve_code()
        code = _generate_code(block_infos, self.repo_id, blocks_class_name, workflow_label, workflow_resolve_code)
        # Sanity-check the generated source; a syntax error is logged but not
        # fatal so the user still gets the (broken) output to inspect.
        try:
            ast.parse(code)
        except SyntaxError as e:
            logger.warning(f"Generated code has syntax error: {e}")
        if self.deploy:
            self._deploy_to_space(code)
        elif self.output:
            with open(self.output, "w") as f:
                f.write(code)
            print(f"Generated daggr app: {self.output}")
            print(f" Pipeline: {blocks_class_name}")
            print(f" Workflow: {workflow_label}")
            print(f" Blocks: {len(block_infos)}")
            print(f"\nRun with: python {self.output}")
        else:
            print(f"Launching daggr app for {blocks_class_name} ({workflow_label} workflow)...")
            # Write the script to a temp file so tracebacks have a real
            # filename, then execute it in-process as if it were __main__.
            tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, prefix="daggr_")
            tmp.write(code)
            tmp.close()
            logger.info(f"Generated temp script: {tmp.name}")
            exec(compile(code, tmp.name, "exec"), {"__name__": "__main__"})

    def _deploy_to_space(self, code):
        """Write the generated app to a temp script and hand it to daggr's
        deploy helper, forwarding hardware/privacy/requirements settings."""
        import os
        from pathlib import Path

        from daggr.cli import _deploy

        tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, prefix="daggr_")
        tmp.write(code)
        tmp.close()
        req_path = self.requirements
        if not req_path:
            # No user-supplied requirements: ship the default set.
            req_file = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, prefix="daggr_req_")
            req_file.write(DEFAULT_REQUIREMENTS)
            req_file.close()
            req_path = req_file.name
        secrets = {}
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            secrets["HF_TOKEN"] = hf_token
        # "org/name" targets an organization Space; bare "name" a user Space.
        deploy_name = self.deploy
        deploy_org = None
        if "/" in deploy_name:
            deploy_org, deploy_name = deploy_name.rsplit("/", 1)
        _deploy(
            script_path=Path(tmp.name),
            name=deploy_name,
            title=None,
            org=deploy_org,
            private=self.private,
            hardware=self.hardware,
            secrets=secrets,
            requirements_path=req_path,
            dry_run=False,
        )

    def _get_workflow_resolve_code(self):
        """Return the source snippet the generated app will run to re-resolve
        the same execution blocks at runtime (mirrors the logic in run())."""
        if self.workflow:
            return f"_pipeline._blocks.get_workflow({self.workflow!r})"
        elif self.trigger_inputs:
            kwargs_str = ", ".join(f"{name!r}: True" for name in self.trigger_inputs)
            return f"_pipeline._blocks.get_execution_blocks(**{{{kwargs_str}}})"
        else:
            return "_pipeline._blocks.get_execution_blocks()"
# ---------------------------------------------------------------------------
# Block analysis
# ---------------------------------------------------------------------------
def _get_block_info(exec_blocks):
    """Snapshot each sub-block's name, class, description and I/O specs.

    Blocks lacking `inputs` / `intermediate_outputs` attributes contribute
    empty lists; the classification fields are filled in later.
    """
    infos = []
    for block_name, sub_block in exec_blocks.sub_blocks.items():
        block_inputs = list(sub_block.inputs) if hasattr(sub_block, "inputs") else []
        block_outputs = (
            list(sub_block.intermediate_outputs) if hasattr(sub_block, "intermediate_outputs") else []
        )
        info = BlockInfo(
            name=block_name,
            class_name=sub_block.__class__.__name__,
            description=getattr(sub_block, "description", "") or "",
            inputs=block_inputs,
            outputs=block_outputs,
        )
        infos.append(info)
    return infos
def _filter_outputs(block_infos):
    """Prune each block's outputs down to the ones somebody consumes.

    An output survives when a later block has an input of the same name, or —
    for the final block only — when it looks presentable to the user (contains
    a PIL image, or is not an internal torch type). Mutates `info.outputs`
    in place.
    """
    # Build suffix sets of downstream input names in ONE reverse pass instead
    # of rebuilding the set from scratch for every block (was O(n^2 * m)).
    # downstream[i] = names of inputs of blocks strictly after index i-1,
    # i.e. downstream[i + 1] is what blocks after block i consume.
    n = len(block_infos)
    downstream = [set() for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        names = {inp.name for inp in block_infos[i].inputs if inp.name}
        downstream[i] = downstream[i + 1] | names
    for i, info in enumerate(block_infos):
        consumed_later = downstream[i + 1]
        is_last = i == n - 1
        info.outputs = [
            out
            for out in info.outputs
            if out.name in consumed_later
            or (is_last and (_contains_pil_image(out.type_hint) or not _is_internal_type(out.type_hint)))
        ]
def _classify_inputs(block_infos):
    """Partition every block's inputs into three buckets (mutates in place).

    - port_connections: the name was produced by an earlier block;
    - user_inputs: user-facing values that deserve a Gradio widget;
    - fixed_inputs: everything else, left at its default.
    """
    produced_by = {}  # output name -> name of the block that produced it
    for info in block_infos:
        widgets, ports, fixed = [], [], []
        for inp in info.inputs:
            if inp.name is None:
                continue
            if inp.name in produced_by:
                ports.append((inp.name, produced_by[inp.name]))
            elif _is_user_facing(inp):
                widgets.append(inp)
            else:
                fixed.append(inp)
        info.user_inputs = widgets
        info.port_connections = ports
        info.fixed_inputs = fixed
        # Register this block's outputs for downstream connection matching.
        for out in info.outputs:
            if out.name:
                produced_by[out.name] = info.name
# ---------------------------------------------------------------------------
# Type helpers
# ---------------------------------------------------------------------------
def _get_type_name(type_hint):
if type_hint is None:
return None
if hasattr(type_hint, "__name__"):
return type_hint.__name__
if hasattr(type_hint, "__module__") and hasattr(type_hint, "__qualname__"):
return f"{type_hint.__module__}.{type_hint.__qualname__}"
return str(type_hint)
def _is_internal_type(type_hint):
    """True for torch tensors/generators/dtypes and plain dicts.

    Such values are routed through the tensor store instead of getting a
    user-facing widget. A None hint is NOT internal.
    """
    if type_hint is None:
        return False
    readable = _get_type_name(type_hint)
    if readable is None:
        return False
    if readable in INTERNAL_TYPE_NAMES or readable in INTERNAL_TYPE_FULL_NAMES:
        return True
    # Also catch internal types buried inside generics (e.g. Optional[...]).
    rendered = str(type_hint)
    if any(full_name in rendered for full_name in INTERNAL_TYPE_FULL_NAMES):
        return True
    return rendered == "dict" or rendered.startswith("dict[")
def _contains_pil_image(type_hint):
    """True when the hint is PIL.Image.Image or nests it inside a generic."""
    from PIL import Image

    if type_hint is Image.Image:
        return True
    # Recurse into generic parameters (List[Image], Optional[Image], ...).
    nested = typing.get_args(type_hint)
    if not nested:
        return False
    return any(_contains_pil_image(arg) for arg in nested)
def _is_list_or_tuple_type(type_hint):
origin = typing.get_origin(type_hint)
if origin in (list, tuple):
return True
if origin is typing.Union or origin is types.UnionType:
return any(_is_list_or_tuple_type(a) for a in typing.get_args(type_hint))
return False
def _resolve_from_template(inp):
    """Return (type_hint, default), filling gaps from the shared templates.

    Only fields that are None on `inp` are taken from the matching entry in
    INPUT_PARAM_TEMPLATES; explicit values always win.
    """
    from ..modular_pipelines.modular_pipeline_utils import INPUT_PARAM_TEMPLATES

    type_hint, default = inp.type_hint, inp.default
    if inp.name in INPUT_PARAM_TEMPLATES:
        template = INPUT_PARAM_TEMPLATES[inp.name]
        type_hint = template.get("type_hint") if type_hint is None else type_hint
        default = template.get("default") if default is None else default
    return type_hint, default
def _is_user_facing(inp):
    """Should this input get a Gradio widget?

    With a known type: images always qualify, internal torch types never do.
    Without one: only the well-known slider parameters qualify.
    """
    type_hint, _ = _resolve_from_template(inp)
    if type_hint is None:
        return inp.name in SLIDER_PARAMS
    return _contains_pil_image(type_hint) or not _is_internal_type(type_hint)
def _sanitize_name(name):
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)
if sanitized and sanitized[0].isdigit():
sanitized = f"_{sanitized}"
return sanitized
# ---------------------------------------------------------------------------
# Gradio component code strings
# ---------------------------------------------------------------------------
def _type_hint_to_gradio(type_hint, param_name, default=None):
    """Return source code for the Gradio input component of a parameter.

    Resolution order: known dropdown params, known slider params, internal
    types (returns None — no widget), PIL images, str/int/float/bool, then a
    Textbox fallback. `default`, when given, pre-fills the widget's value.
    """
    if param_name in DROPDOWN_PARAMS:
        opts = DROPDOWN_PARAMS[param_name]
        val = default if default is not None else opts.get("value")
        return f'gr.Dropdown(label="{param_name}", choices={opts["choices"]!r}, value={val!r})'
    if param_name in SLIDER_PARAMS:
        opts = SLIDER_PARAMS[param_name]
        val = default if default is not None else opts.get("minimum", 0)
        return (
            f'gr.Slider(label="{param_name}", value={val!r}, '
            f"minimum={opts['minimum']}, maximum={opts['maximum']}, step={opts['step']})"
        )
    # Internal (tensor-like) inputs get no user-facing widget at all.
    if type_hint is not None and _is_internal_type(type_hint):
        return None
    if type_hint is not None and _contains_pil_image(type_hint):
        return f'gr.Image(label="{param_name}")'
    if type_hint is str:
        default_repr = f", value={default!r}" if default is not None else ""
        return f'gr.Textbox(label="{param_name}", lines=1{default_repr})'
    if type_hint is int:
        val = f", value={default!r}" if default is not None else ""
        return f'gr.Number(label="{param_name}", precision=0{val})'
    if type_hint is float:
        val = f", value={default!r}" if default is not None else ""
        return f'gr.Number(label="{param_name}"{val})'
    if type_hint is bool:
        val = default if default is not None else False
        return f'gr.Checkbox(label="{param_name}", value={val!r})'
    # Unknown type: fall back to a plain textbox.
    if default is not None:
        return f'gr.Textbox(label="{param_name}", value={default!r})'
    return f'gr.Textbox(label="{param_name}")'
def _output_type_to_gradio(type_hint, param_name):
    """Return source code for the Gradio component displaying an output port.

    Images, strings and numbers get visible widgets; anything else (including
    internal values) becomes a hidden Textbox so the port still exists.
    """
    has_hint = type_hint is not None
    if has_hint and _contains_pil_image(type_hint):
        return f'gr.Image(label="{param_name}")'
    if type_hint is str:
        return f'gr.Textbox(label="{param_name}")'
    if type_hint is int or type_hint is float:
        return f'gr.Number(label="{param_name}")'
    return f'gr.Textbox(label="{param_name}", visible=False)'
# ---------------------------------------------------------------------------
# Code generation
# ---------------------------------------------------------------------------
def _generate_code(block_infos, repo_id, blocks_class_name, workflow_label, workflow_resolve_code):
    """Emit the full source of the daggr app as one string.

    Layout of the generated file: module docstring, header (pipeline loader +
    tensor store), optional image-saving helper, one `run_<block>` function
    per block, InputNode/FnNode definitions, and finally the Graph launch with
    daggr's result caching disabled.

    NOTE(review): indentation inside the emitted string literals below was
    collapsed in the source rendering; it is reconstructed here as standard
    4/8-space nesting — confirm against the original generated output.
    """
    sections = []
    # Pre-compute metadata
    blocks_with_image_inputs = {}  # block name -> input names that are PIL images
    blocks_with_parseable_inputs = {}  # block name -> list/tuple inputs parsed from text
    blocks_with_image_outputs = set()
    port_conn_source = {}  # (block name, input name) -> producing block name
    for info in block_infos:
        img_names, parse_names = set(), set()
        for inp in info.inputs:
            if inp.name is None:
                continue
            resolved_type, _ = _resolve_from_template(inp)
            if resolved_type is not None and _contains_pil_image(resolved_type):
                img_names.add(inp.name)
            elif resolved_type is not None and _is_list_or_tuple_type(resolved_type):
                parse_names.add(inp.name)
        if img_names:
            blocks_with_image_inputs[info.name] = img_names
        if parse_names:
            blocks_with_parseable_inputs[info.name] = parse_names
        if any(out.type_hint is not None and _contains_pil_image(out.type_hint) for out in info.outputs):
            blocks_with_image_outputs.add(info.name)
        for conn_name, source_block in info.port_connections:
            port_conn_source[(info.name, conn_name)] = source_block
    needs_image_io = blocks_with_image_inputs or blocks_with_image_outputs
    needs_parsing = bool(blocks_with_parseable_inputs)
    # -- Header & helpers --
    # Only import what the generated app actually needs.
    extra_imports = ""
    if needs_parsing:
        extra_imports += "import ast\n"
    if needs_image_io:
        extra_imports += "import tempfile\n"
    sections.append(f'"""Daggr app for {blocks_class_name} ({workflow_label} workflow)')
    sections.append("Generated by: diffusers-cli daggr")
    sections.append('"""')
    header = _HEADER_TEMPLATE.format(
        extra_imports=extra_imports, repo_id=repr(repo_id), workflow_resolve_code=workflow_resolve_code
    )
    sections.extend(header.splitlines())
    if needs_image_io:
        sections.extend(_SAVE_IMAGE_TEMPLATE.splitlines())
    # -- Block functions --
    # One run_<block> function per block: seed a PipelineState from the
    # function arguments (pulling internal values from the tensor store),
    # execute the block, stash internal outputs, and return the rest.
    for info in block_infos:
        fn_name = f"run_{_sanitize_name(info.name)}"
        input_names = [inp.name for inp in info.inputs if inp.name is not None]
        port_conn_names = {c for c, _ in info.port_connections}
        has_image_out = info.name in blocks_with_image_outputs
        body_lines = []
        body_lines.append("    from diffusers.modular_pipelines.modular_pipeline import PipelineState")
        body_lines.append("")
        body_lines.append("    pipe, exec_blocks = _get_pipeline()")
        body_lines.append("    state = PipelineState()")
        if info.name in blocks_with_image_inputs:
            # Gradio may hand images over as file paths; open them as PIL.
            body_lines.append("    from PIL import Image as PILImage")
            body_lines.append("")
            for img_name in blocks_with_image_inputs[info.name]:
                body_lines.append(f"    if {img_name} is not None and isinstance({img_name}, str):")
                body_lines.append(f"        {img_name} = PILImage.open({img_name})")
        if info.name in blocks_with_parseable_inputs:
            # Text widgets deliver strings; parse list/tuple literals safely.
            for parse_name in blocks_with_parseable_inputs[info.name]:
                body_lines.append(f"    if {parse_name} is not None and isinstance({parse_name}, str):")
                body_lines.append(
                    f"        {parse_name} = ast.literal_eval({parse_name}.strip()) if {parse_name}.strip() else None"
                )
        for n in input_names:
            if n in port_conn_names:
                # Internal values travel via the tensor store keyed by
                # "<producer block>:<name>"; fall back to the port value.
                source = port_conn_source[(info.name, n)]
                body_lines.append(f'    state.set("{n}", _tensor_store.get("{source}:{n}", {n}))')
            else:
                body_lines.append(f'    state.set("{n}", {n})')
        if info.sub_block_names:
            for n in info.sub_block_names:
                body_lines.append(f'    _, state = exec_blocks.sub_blocks["{n}"](pipe, state)')
        else:
            body_lines.append(f'    _, state = exec_blocks.sub_blocks["{info.name}"](pipe, state)')
        # Stash internal outputs; downstream blocks fetch them by key.
        for o in info.outputs:
            if o.type_hint is not None and _is_internal_type(o.type_hint):
                body_lines.append(f'    _tensor_store["{info.name}:{o.name}"] = state.get("{o.name}")')
        # Return values: images are saved to temp files, internal values are
        # represented by their tensor-store key, the rest pass through.
        if len(info.outputs) == 0:
            body_lines.append("    return None")
        elif len(info.outputs) == 1:
            out = info.outputs[0]
            if has_image_out and _contains_pil_image(out.type_hint):
                body_lines.append(f'    return _save_image(state.get("{out.name}"))')
            elif out.type_hint is not None and _is_internal_type(out.type_hint):
                body_lines.append(f'    return "{info.name}:{out.name}"')
            else:
                body_lines.append(f'    return state.get("{out.name}")')
        else:
            parts = []
            for o in info.outputs:
                if has_image_out and o.type_hint is not None and _contains_pil_image(o.type_hint):
                    parts.append(f'_save_image(state.get("{o.name}"))')
                elif o.type_hint is not None and _is_internal_type(o.type_hint):
                    parts.append(f'"{info.name}:{o.name}"')
                else:
                    parts.append(f'state.get("{o.name}")')
            body_lines.append(f"    return {', '.join(parts)}")
        body = "\n".join(body_lines)
        sections.append(f"\n\ndef {fn_name}({', '.join(input_names)}):\n{body}\n")
    # -- Node definitions --
    sections.append("\n\n# -- Pipeline Blocks --")
    node_var_names = {}  # block name -> FnNode variable name
    input_node_var_names = []
    user_input_sources = {}  # input name -> InputNode variable feeding it
    for info in block_infos:
        var_name = f"{_sanitize_name(info.name)}_node"
        node_var_names[info.name] = var_name
        fn_name = f"run_{_sanitize_name(info.name)}"
        display_name = info.name.replace("_", " ").replace(".", " > ").title()
        # One shared InputNode per distinct user-facing input name.
        for inp in info.user_inputs:
            if inp.name in user_input_sources:
                continue
            resolved_type, resolved_default = _resolve_from_template(inp)
            gradio_comp = _type_hint_to_gradio(resolved_type, inp.name, resolved_default)
            if gradio_comp:
                input_var = f"{_sanitize_name(inp.name)}_input"
                sections.append(
                    f'\n{input_var} = InputNode("{inp.name}", ports={{\n    "{inp.name}": {gradio_comp},\n}})'
                )
                input_node_var_names.append(input_var)
                user_input_sources[inp.name] = input_var
        # Wire each input: upstream node port > InputNode port > literal default > None.
        input_entries = []
        for inp in info.inputs:
            if inp.name is None:
                continue
            connected = False
            for conn_name, source_block in info.port_connections:
                if conn_name == inp.name:
                    input_entries.append(f'        "{inp.name}": {node_var_names[source_block]}.{inp.name},')
                    connected = True
                    break
            if not connected:
                if inp.name in user_input_sources:
                    input_entries.append(f'        "{inp.name}": {user_input_sources[inp.name]}.{inp.name},')
                elif inp.default is not None:
                    input_entries.append(f'        "{inp.name}": {inp.default!r},')
                else:
                    input_entries.append(f'        "{inp.name}": None,')
        output_entries = []
        for out in info.outputs:
            output_entries.append(f'        "{out.name}": {_output_type_to_gradio(out.type_hint, out.name)},')
        node_parts = [f"\n{var_name} = FnNode(", f"    fn={fn_name},", f'    name="{display_name}",']
        if input_entries:
            node_parts.append("    inputs={")
            node_parts.extend(input_entries)
            node_parts.append("    },")
        if output_entries:
            node_parts.append("    outputs={")
            node_parts.extend(output_entries)
            node_parts.append("    },")
        node_parts.append(")")
        sections.append("\n".join(node_parts))
    # -- Graph launch with cache disabled --
    all_node_vars = input_node_var_names + [node_var_names[info.name] for info in block_infos]
    graph_name = f"{blocks_class_name} - {workflow_label}"
    nodes_str = ", ".join(all_node_vars)
    sections.append("")
    sections.append("")
    sections.append("# -- Graph --")
    sections.extend(_CACHE_DISABLE_TEMPLATE.splitlines())
    sections.append("")
    sections.append(f'graph = Graph("{graph_name}", nodes=[{nodes_str}])')
    sections.append("graph.launch()")
    sections.append("")
    return "\n".join(sections)

View File

@@ -16,7 +16,6 @@
from argparse import ArgumentParser
from .custom_blocks import CustomBlocksCommand
from .daggr_app import DaggrCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
@@ -29,7 +28,6 @@ def main():
EnvironmentCommand.register_subcommand(commands_parser)
FP16SafetensorsCommand.register_subcommand(commands_parser)
CustomBlocksCommand.register_subcommand(commands_parser)
DaggrCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,59 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import LTXVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class LTXTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return LTXVideoTransformer3DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 2
height = 16
width = 16
embedding_dim = 16
sequence_length = 16
def output_shape(self) -> tuple[int, int]:
return (512, 4)
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def input_shape(self) -> tuple[int, int]:
return (512, 4)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self):
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"num_frames": num_frames,
"height": height,
"width": width,
}
@property
def input_shape(self):
return (512, 4)
@property
def output_shape(self):
return (512, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -75,16 +62,57 @@ class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
"qk_norm": "rms_norm_across_heads",
"caption_channels": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_channels = 4
num_frames = 2
height = 16
width = 16
embedding_dim = 16
sequence_length = 16
return {
"hidden_states": randn_tensor(
(batch_size, num_frames * height * width, num_channels),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
"num_frames": num_frames,
"height": height,
"width": width,
}
class TestLTXTransformer(LTXTransformerTesterConfig, ModelTesterMixin):
"""Core model tests for LTX Video Transformer."""
class TestLTXTransformerMemory(LTXTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for LTX Video Transformer."""
class TestLTXTransformerTraining(LTXTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for LTX Video Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTXVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
super().test_gradient_checkpointing_is_applied(expected_set={"LTXVideoTransformer3DModel"})
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
class TestLTXTransformerCompile(LTXTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for LTX Video Transformer."""
def prepare_init_args_and_inputs_for_common(self):
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
# class TestLTXTransformerBitsAndBytes(LTXTransformerTesterConfig, BitsAndBytesTesterMixin):
# """BitsAndBytes quantization tests for LTX Video Transformer."""
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
# class TestLTXTransformerTorchAo(LTXTransformerTesterConfig, TorchAoTesterMixin):
# """TorchAo quantization tests for LTX Video Transformer."""

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,77 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
from diffusers import LTX2VideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class LTX2TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return LTX2VideoTransformer3DModel
@property
def dummy_input(self):
# Common
batch_size = 2
def output_shape(self) -> tuple[int, int]:
return (512, 4)
# Video
num_frames = 2
num_channels = 4
height = 16
width = 16
@property
def input_shape(self) -> tuple[int, int]:
return (512, 4)
# Audio
audio_num_frames = 9
audio_num_channels = 2
num_mel_bins = 2
@property
def main_input_name(self) -> str:
return "hidden_states"
# Text
embedding_dim = 16
sequence_length = 16
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
torch_device
)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self):
return {
"hidden_states": hidden_states,
"audio_hidden_states": audio_hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"audio_encoder_hidden_states": audio_encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"num_frames": num_frames,
"height": height,
"width": width,
"audio_num_frames": audio_num_frames,
"fps": 25.0,
}
@property
def input_shape(self):
return (512, 4)
@property
def output_shape(self):
return (512, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
@@ -101,122 +72,80 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
"caption_channels": 16,
"rope_double_precision": False,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_frames = 2
num_channels = 4
height = 16
width = 16
audio_num_frames = 9
audio_num_channels = 2
num_mel_bins = 2
embedding_dim = 16
sequence_length = 16
return {
"hidden_states": randn_tensor(
(batch_size, num_frames * height * width, num_channels),
generator=self.generator,
device=torch_device,
),
"audio_hidden_states": randn_tensor(
(batch_size, audio_num_frames, audio_num_channels * num_mel_bins),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"audio_encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": (randn_tensor((batch_size,), generator=self.generator, device=torch_device).abs() * 1000),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
"num_frames": num_frames,
"height": height,
"width": width,
"audio_num_frames": audio_num_frames,
"fps": 25.0,
}
class TestLTX2Transformer(LTX2TransformerTesterConfig, ModelTesterMixin):
"""Core model tests for LTX2 Video Transformer."""
class TestLTX2TransformerMemory(LTX2TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for LTX2 Video Transformer."""
class TestLTX2TransformerTraining(LTX2TransformerTesterConfig, TrainingTesterMixin):
    """Training tests for LTX2 Video Transformer.

    NOTE(review): a stray duplicate
    ``super().test_gradient_checkpointing_is_applied(expected_set={...})`` line
    (a merge leftover dangling after the commented-out consistency test) was
    removed — a bare ``super()`` call is invalid outside a method, and the call
    is already made inside ``test_gradient_checkpointing_is_applied``.
    """

    def test_gradient_checkpointing_is_applied(self):
        # LTX2 applies gradient checkpointing at the top-level transformer module.
        expected_set = {"LTX2VideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    # TODO(review): disabled numerical-consistency test below references a
    # non-public checkpoint ("diffusers-internal-dev/dummy-ltx2"); re-enable
    # once a public tiny LTX2 model is available.
    # def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
    # torch.manual_seed(seed)
    # init_dict, _ = self.prepare_init_args_and_inputs_for_common()
    # # Calculate dummy inputs in a custom manner to ensure compatibility with original code
    # batch_size = 2
    # num_frames = 9
    # latent_frames = 2
    # text_embedding_dim = 16
    # text_seq_len = 16
    # fps = 25.0
    # sampling_rate = 16000.0
    # hop_length = 160.0
    # sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
    # timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
    # num_channels = 4
    # latent_height = 4
    # latent_width = 4
    # hidden_states = torch.randn(
    # (batch_size, num_channels, latent_frames, latent_height, latent_width),
    # generator=torch.manual_seed(seed),
    # dtype=dtype,
    # device="cpu",
    # )
    # # Patchify video latents (with patch_size (1, 1, 1))
    # hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
    # hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    # encoder_hidden_states = torch.randn(
    # (batch_size, text_seq_len, text_embedding_dim),
    # generator=torch.manual_seed(seed),
    # dtype=dtype,
    # device="cpu",
    # )
    # audio_num_channels = 2
    # num_mel_bins = 2
    # latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
    # audio_hidden_states = torch.randn(
    # (batch_size, audio_num_channels, latent_length, num_mel_bins),
    # generator=torch.manual_seed(seed),
    # dtype=dtype,
    # device="cpu",
    # )
    # # Patchify audio latents
    # audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
    # audio_encoder_hidden_states = torch.randn(
    # (batch_size, text_seq_len, text_embedding_dim),
    # generator=torch.manual_seed(seed),
    # dtype=dtype,
    # device="cpu",
    # )
    # inputs_dict = {
    # "hidden_states": hidden_states.to(device=torch_device),
    # "audio_hidden_states": audio_hidden_states.to(device=torch_device),
    # "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
    # "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
    # "timestep": timestep,
    # "num_frames": latent_frames,
    # "height": latent_height,
    # "width": latent_width,
    # "audio_num_frames": num_frames,
    # "fps": 25.0,
    # }
    # model = self.model_class.from_pretrained(
    # "diffusers-internal-dev/dummy-ltx2",
    # subfolder="transformer",
    # device_map="cpu",
    # )
    # # torch.manual_seed(seed)
    # # model = self.model_class(**init_dict)
    # model.to(torch_device)
    # model.eval()
    # with attention_backend("native"):
    # with torch.no_grad():
    # output = model(**inputs_dict)
    # video_output, audio_output = output.to_tuple()
    # self.assertIsNotNone(video_output)
    # self.assertIsNotNone(audio_output)
    # # input & output have to have the same shape
    # video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
    # self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
    # audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
    # self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
    # # Check against expected slice
    # # fmt: off
    # video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
    # audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
    # # fmt: on
    # video_output_flat = video_output.cpu().flatten().float()
    # video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
    # self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
    # audio_output_flat = audio_output.cpu().flatten().float()
    # audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
    # self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
# NOTE(review): this legacy-style test class appears to duplicate
# ``TestLTX2TransformerCompile`` below, which follows the newer
# ``*TesterConfig`` + mixin pattern used by the rest of this file — confirm
# whether this class should be removed.
class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    # Model class exercised by the compile tester mixin.
    model_class = LTX2VideoTransformer3DModel
class TestLTX2TransformerAttention(LTX2TransformerTesterConfig, AttentionTesterMixin):
    """Attention processor tests for LTX2 Video Transformer."""

    def prepare_init_args_and_inputs_for_common(self):
        # Delegate to the standalone tester class for init args and dummy inputs.
        # NOTE(review): assumes ``LTX2TransformerTests`` still exists in this
        # module after the ``*TesterConfig`` refactor — verify.
        legacy_tester = LTX2TransformerTests()
        return legacy_tester.prepare_init_args_and_inputs_for_common()

    @pytest.mark.skip(
        "LTX2Attention does not set is_cross_attention, so fuse_projections tries to fuse Q+K+V together even for cross-attention modules with different input dimensions."
    )
    def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
        # Intentionally empty: the test body is skipped (see decorator).
        pass
class TestLTX2TransformerCompile(LTX2TransformerTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for LTX2 Video Transformer.

    The class body is intentionally empty: all test methods are inherited from
    ``TorchCompileTesterMixin`` and the model configuration from
    ``LTX2TransformerTesterConfig``.
    """
# TODO: Add pretrained_model_name_or_path once a tiny LTX2 model is available on the Hub
# class TestLTX2TransformerBitsAndBytes(LTX2TransformerTesterConfig, BitsAndBytesTesterMixin):
# """BitsAndBytes quantization tests for LTX2 Video Transformer."""
# TODO: Add pretrained_model_name_or_path once a tiny LTX2 model is available on the Hub
# class TestLTX2TransformerTorchAo(LTX2TransformerTesterConfig, TorchAoTesterMixin):
# """TorchAo quantization tests for LTX2 Video Transformer."""