Compare commits

..

4 Commits

Author SHA1 Message Date
yiyi@huggingface.co
a123e95ee2 up 2026-02-27 11:08:59 +00:00
yiyi@huggingface.co
944a478989 up 2026-02-27 11:08:16 +00:00
Jerry Song
40e96454f1 Fix LTX-2 image-to-video generation failure in two stages generation (#13187)
* Fix LTX-2 image-to-video generation failure in two stages generation

In LTX-2's two-stage image-to-video generation task, specifically after
the upsampling step, a shape mismatch occurs between the `latents` and
the `conditioning_mask`, which causes an error in function
`_create_noised_state`.

Fix it by creating the `conditioning_mask` based on the shape of the
`latents`.

* Add unit test for LTX-2 i2v two stages inference with upsampler

* Downscaling the upsampler in LTX-2 image-to-video unit test

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-27 00:55:01 -08:00
Varun Chawla
47455bd133 Fix Flash Attention 3 interface for new FA3 return format (#13173)
* Fix Flash Attention 3 interface compatibility for new FA3 versions

Newer versions of flash-attn (after Dao-AILab/flash-attention@ed20940)
no longer return lse by default from flash_attn_3_func. The function
now returns just the output tensor unless return_attn_probs=True is
passed.

Updated _wrapped_flash_attn_3 and _flash_varlen_attention_3 to pass
return_attn_probs and handle both old (always tuple) and new (tensor
or tuple) return formats gracefully.

Fixes #12022

* Simplify _wrapped_flash_attn_3 return unpacking

Since return_attn_probs=True is always passed, the result is
guaranteed to be a tuple. Remove the unnecessary isinstance guard.
2026-02-26 17:34:36 +05:30
8 changed files with 218 additions and 288 deletions

View File

@@ -332,49 +332,4 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
</hfoption>
</hfoptions>
## Dependencies
Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the dependencies and returns a warning if a package is missing or incompatible.
Set a `_requirements` attribute in your block class, mapping package names to version specifiers.
```py
from diffusers.modular_pipelines import PipelineBlock
class MyCustomBlock(PipelineBlock):
_requirements = {
"transformers": ">=4.44.0",
"sentencepiece": ">=0.2.0"
}
```
When there are blocks with different requirements, Diffusers merges their requirements.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
class BlockA(PipelineBlock):
_requirements = {"transformers": ">=4.44.0"}
# ...
class BlockB(PipelineBlock):
_requirements = {"sentencepiece": ">=0.2.0"}
# ...
pipe = SequentialPipelineBlocks.from_blocks_dict({
"block_a": BlockA,
"block_b": BlockB,
})
```
When this block is saved with [`~ModularPipeline.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers returns the following warning.
```md
# missing package
xyz-package was specified in the requirements but wasn't found in the current environment.
# version mismatch
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected.
```
</hfoptions>

View File

@@ -89,6 +89,8 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:

View File

@@ -733,7 +733,7 @@ def _wrapped_flash_attn_3(
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
out, lse, *_ = flash_attn_3_func(
result = flash_attn_3_func(
q=q,
k=k,
v=v,
@@ -750,7 +750,9 @@ def _wrapped_flash_attn_3(
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=True,
)
out, lse, *_ = result
lse = lse.permute(0, 2, 1)
return out, lse
@@ -2701,7 +2703,7 @@ def _flash_varlen_attention_3(
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out, lse, *_ = flash_attn_3_varlen_func(
result = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2711,7 +2713,13 @@ def _flash_varlen_attention_3(
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out

View File

@@ -40,7 +40,6 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
_validate_requirements,
combine_inputs,
combine_outputs,
format_components,
@@ -291,7 +290,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json"
model_name = None
_requirements: dict[str, str] | None = None
_workflow_map = None
@classmethod
@@ -406,9 +404,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
@@ -433,13 +428,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map)
# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
@@ -1250,14 +1240,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs=self.expected_configs,
)
@property
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""

View File

@@ -22,12 +22,10 @@ from typing import Any, Literal, Type, Union, get_args, get_origin
import PIL.Image
import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
from ..utils.import_utils import _is_package_available
if is_torch_available():
@@ -50,13 +48,7 @@ This modular pipeline is composed of the following blocks:
## Model Components
{components_description} {configs_section}
## Input/Output Specification
### Inputs {inputs_description}
### Outputs {outputs_description}
{components_description} {configs_section} {io_specification_section}
"""
@@ -801,6 +793,46 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_params_markdown(params, header="Inputs"):
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
Suitable for model cards rendered on Hugging Face Hub.
Args:
params: list of InputParam or OutputParam objects to format
header: Header text (e.g. "Inputs" or "Outputs")
Returns:
A formatted markdown string, or empty string if params is empty.
"""
if not params:
return ""
def get_type_str(type_hint):
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
return " | ".join(type_strs)
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
lines = [f"**{header}:**\n"]
for param in params:
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
param_str = f"- `{name}` (`{type_str}`"
if hasattr(param, "required") and not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to `{param.default}`"
param_str += ")"
desc = param.description if param.description else "No description provided"
param_str += f": {desc}"
lines.append(param_str)
return "\n".join(lines)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
@@ -974,89 +1006,6 @@ def make_doc_string(
return output
def _validate_requirements(reqs):
if reqs is None:
normalized_reqs = {}
else:
if not isinstance(reqs, dict):
raise ValueError(
"Requirements must be provided as a dictionary mapping package names to version specifiers."
)
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return {}
final: dict[str, str] = {}
for req, specified_ver in normalized_reqs.items():
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver:
try:
specifier = SpecifierSet(specified_ver)
except InvalidSpecifier as err:
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
if req_actual_ver == "N/A":
logger.warning(
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
)
elif not specifier.contains(req_actual_ver, prereleases=True):
logger.warning(
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
)
final[req] = specified_ver
return final
def _normalize_requirements(reqs):
if not reqs:
return {}
normalized: "OrderedDict[str, str]" = OrderedDict()
def _accumulate(mapping: dict[str, Any]):
for pkg, spec in mapping.items():
if isinstance(spec, dict):
# This is recursive because blocks are composable. This way, we can merge requirements
# from multiple blocks.
_accumulate(spec)
continue
pkg_name = str(pkg).strip()
if not pkg_name:
raise ValueError("Requirement package name cannot be empty.")
spec_str = "" if spec is None else str(spec).strip()
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
spec_str = f"=={spec_str}"
existing_spec = normalized.get(pkg_name)
if existing_spec is not None:
if not existing_spec and spec_str:
normalized[pkg_name] = spec_str
elif existing_spec and spec_str and existing_spec != spec_str:
try:
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
except InvalidSpecifier:
logger.warning(
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
)
else:
normalized[pkg_name] = str(combined_spec)
continue
normalized[pkg_name] = spec_str
_accumulate(reqs)
return normalized
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
"""
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
@@ -1140,8 +1089,7 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- io_specification_section: Input/Output specification (per-workflow or unified)
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
@@ -1194,63 +1142,74 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
inputs = blocks.inputs
outputs = blocks.outputs
# Branch on whether workflows are defined
has_workflows = getattr(blocks, "_workflow_map", None) is not None
# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]
if has_workflows:
# Per-workflow I/O sections
workflow_map = blocks._workflow_map
parts = []
if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
# use that as the shared output for all workflows
blocks_outputs = blocks.outputs
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
shared_outputs = (
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
)
if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
# Summary section using existing format_workflow
parts.append("## Supported Workflows\n")
parts.append(format_workflow(workflow_map))
parts.append("")
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
# Per-workflow details
for wf_name, trigger_inputs in workflow_map.items():
trigger_input_names = set(trigger_inputs.keys())
try:
workflow_blocks = blocks.get_workflow(wf_name)
except Exception:
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
parts.append(f"> **Trigger inputs**: {', '.join(f'`{t}`' for t in trigger_input_names)}\n")
parts.append("*Could not resolve workflow blocks.*\n")
parts.append("</details>\n")
continue
# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
wf_inputs = workflow_blocks.inputs
wf_outputs = shared_outputs if shared_outputs is not None else workflow_blocks.outputs
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
parts.append(f"> **Trigger inputs**: {', '.join(f'`{t}`' for t in trigger_input_names)}\n")
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
inputs_str = format_params_markdown(wf_inputs, "Inputs")
parts.append(inputs_str if inputs_str else "No specific inputs defined.")
parts.append("")
outputs_str = format_params_markdown(wf_outputs, "Outputs")
parts.append(outputs_str if outputs_str else "No specific outputs defined.")
parts.append("")
parts.append("</details>\n")
io_specification_section = "\n".join(parts)
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
trigger_inputs_section = ""
else:
# Unified I/O section (original behavior)
inputs = blocks.inputs
outputs = blocks.outputs
inputs_str = format_params_markdown(inputs, "Inputs")
outputs_str = format_params_markdown(outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution
This pipeline contains blocks that are selected at runtime based on inputs:
@@ -1263,7 +1222,18 @@ This pipeline contains blocks that are selected at runtime based on inputs:
if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
if has_workflows:
# Derive tags from workflow names
workflow_names = set(blocks._workflow_map.keys())
if any("inpainting" in wf for wf in workflow_names):
tags.append("inpainting")
if any("image2image" in wf for wf in workflow_names):
tags.append("image-to-image")
if any("controlnet" in wf for wf in workflow_names):
tags.append("controlnet")
if any("text2image" in wf for wf in workflow_names):
tags.append("text-to-image")
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
@@ -1291,8 +1261,7 @@ This pipeline uses a {block_count}-block architecture that can be customized and
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"io_specification_section": io_specification_section,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}

View File

@@ -699,9 +699,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
if latents.ndim == 5:
# conditioning_mask needs to the same shape as latents in two stages generation.
batch_size, _, num_frames, height, width = latents.shape
mask_shape = (batch_size, 1, num_frames, height, width)
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
@@ -710,6 +714,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
else:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)

View File

@@ -1,6 +1,4 @@
import gc
import json
import os
import tempfile
from typing import Callable
@@ -10,7 +8,6 @@ import torch
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -20,13 +17,7 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
)
from diffusers.utils import logging
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
torch_device,
)
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
class ModularPipelineTesterMixin:
@@ -409,56 +400,6 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def test_custom_requirements_save_load(self):
pipe = self.get_dummy_block_pipe()
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
config_path = os.path.join(tmpdir, "modular_config.json")
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_warnings(self):
pipe = self.get_dummy_block_pipe()
with tempfile.TemporaryDirectory() as tmpdir:
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(tmpdir)
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:
@@ -513,8 +454,7 @@ class TestModularModelCardContent:
"blocks_description",
"components_description",
"configs_section",
"inputs_description",
"outputs_description",
"io_specification_section",
"trigger_inputs_section",
"tags",
]
@@ -611,18 +551,19 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(inputs=inputs)
content = generate_modular_model_card_content(blocks)
assert "**Required:**" in content["inputs_description"]
assert "**Optional:**" in content["inputs_description"]
assert "prompt" in content["inputs_description"]
assert "num_steps" in content["inputs_description"]
assert "default: `50`" in content["inputs_description"]
io_section = content["io_specification_section"]
assert "**Inputs:**" in io_section
assert "prompt" in io_section
assert "num_steps" in io_section
assert "*optional*" in io_section
assert "defaults to `50`" in io_section
def test_inputs_description_empty(self):
"""Test handling of pipelines without specific inputs."""
blocks = self.create_mock_blocks(inputs=[])
content = generate_modular_model_card_content(blocks)
assert "No specific inputs defined" in content["inputs_description"]
assert "No specific inputs defined" in content["io_specification_section"]
def test_outputs_description_formatting(self):
"""Test that outputs are correctly formatted."""
@@ -632,15 +573,16 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(outputs=outputs)
content = generate_modular_model_card_content(blocks)
assert "images" in content["outputs_description"]
assert "Generated images" in content["outputs_description"]
io_section = content["io_specification_section"]
assert "images" in io_section
assert "Generated images" in io_section
def test_outputs_description_empty(self):
"""Test handling of pipelines without specific outputs."""
blocks = self.create_mock_blocks(outputs=[])
content = generate_modular_model_card_content(blocks)
assert "Standard pipeline outputs" in content["outputs_description"]
assert "Standard pipeline outputs" in content["io_specification_section"]
def test_trigger_inputs_section_with_triggers(self):
"""Test that trigger inputs section is generated when present."""

View File

@@ -24,7 +24,8 @@ from diffusers import (
LTX2ImageToVideoPipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
@@ -174,6 +175,15 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
return components
def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1):
upsampler = LTX2LatentUpsamplerModel(
in_channels=in_channels,
mid_channels=mid_channels,
num_blocks_per_stage=num_blocks_per_stage,
)
return upsampler
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
@@ -287,5 +297,60 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_two_stages_inference_with_upsampler(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "latent"
first_stage_output = pipe(**inputs)
video_latent = first_stage_output.frames
audio_latent = first_stage_output.audio
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1])
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler)
upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0]
self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32))
inputs["latents"] = upscaled_video_latent
inputs["audio_latents"] = audio_latent
inputs["output_type"] = "pt"
second_stage_output = pipe(**inputs)
video = second_stage_output.frames
audio = second_stage_output.audio
self.assertEqual(video.shape, (1, 5, 3, 64, 64))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079
]
)
expected_audio_slice = torch.tensor(
[
0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)