Compare commits

...

5 Commits

Author SHA1 Message Date
yiyi@huggingface.co
e34f085d2e update 2026-02-28 06:10:38 +00:00
yiyi@huggingface.co
a123e95ee2 up 2026-02-27 11:08:59 +00:00
yiyi@huggingface.co
944a478989 up 2026-02-27 11:08:16 +00:00
Jerry Song
40e96454f1 Fix LTX-2 image-to-video generation failure in two stages generation (#13187)
* Fix LTX-2 image-to-video generation failure in two stages generation

In LTX-2's two-stage image-to-video generation task, specifically after
the upsampling step, a shape mismatch occurs between the `latents` and
the `conditioning_mask`, which causes an error in function
`_create_noised_state`.

Fix it by creating the `conditioning_mask` based on the shape of the
`latents`.

* Add unit test for LTX-2 i2v two stages inference with upsampler

* Downscaling the upsampler in LTX-2 image-to-video unit test

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-27 00:55:01 -08:00
Varun Chawla
47455bd133 Fix Flash Attention 3 interface for new FA3 return format (#13173)
* Fix Flash Attention 3 interface compatibility for new FA3 versions

Newer versions of flash-attn (after Dao-AILab/flash-attention@ed20940)
no longer return lse by default from flash_attn_3_func. The function
now returns just the output tensor unless return_attn_probs=True is
passed.

Updated _wrapped_flash_attn_3 and _flash_varlen_attention_3 to pass
return_attn_probs and handle both old (always tuple) and new (tensor
or tuple) return formats gracefully.

Fixes #12022

* Simplify _wrapped_flash_attn_3 return unpacking

Since return_attn_probs=True is always passed, the result is
guaranteed to be a tuple. Remove the unnecessary isinstance guard.
2026-02-26 17:34:36 +05:30
7 changed files with 229 additions and 85 deletions

View File

@@ -733,7 +733,7 @@ def _wrapped_flash_attn_3(
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
out, lse, *_ = flash_attn_3_func(
result = flash_attn_3_func(
q=q,
k=k,
v=v,
@@ -750,7 +750,9 @@ def _wrapped_flash_attn_3(
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=True,
)
out, lse, *_ = result
lse = lse.permute(0, 2, 1)
return out, lse
@@ -2701,7 +2703,7 @@ def _flash_varlen_attention_3(
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out, lse, *_ = flash_attn_3_varlen_func(
result = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2711,7 +2713,13 @@ def _flash_varlen_attention_3(
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out

View File

@@ -1836,6 +1836,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
update_model_card = kwargs.pop("update_model_card", False)
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
# Generate modular pipeline card content
@@ -1848,6 +1849,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
is_pipeline=True,
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
is_modular=True,
update_model_card=update_model_card,
)
model_card = populate_model_card(model_card, tags=card_content["tags"])

View File

@@ -50,11 +50,7 @@ This modular pipeline is composed of the following blocks:
{components_description} {configs_section}
## Input/Output Specification
### Inputs {inputs_description}
### Outputs {outputs_description}
{io_specification_section}
"""
@@ -799,6 +795,46 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_params_markdown(params, header="Inputs"):
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
Suitable for model cards rendered on Hugging Face Hub.
Args:
params: list of InputParam or OutputParam objects to format
header: Header text (e.g. "Inputs" or "Outputs")
Returns:
A formatted markdown string, or empty string if params is empty.
"""
if not params:
return ""
def get_type_str(type_hint):
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
return " | ".join(type_strs)
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
lines = [f"**{header}:**\n"] if header else []
for param in params:
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
param_str = f"- `{name}` (`{type_str}`"
if hasattr(param, "required") and not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to `{param.default}`"
param_str += ")"
desc = param.description if param.description else "No description provided"
param_str += f": {desc}"
lines.append(param_str)
return "\n".join(lines)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
@@ -1055,8 +1091,7 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- io_specification_section: Input/Output specification (per-workflow or unified)
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
@@ -1075,15 +1110,6 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if block_desc:
blocks_desc_parts.append(f" - {block_desc}")
# add sub-blocks if any
if hasattr(block, "sub_blocks") and block.sub_blocks:
for sub_name, sub_block in block.sub_blocks.items():
sub_class = sub_block.__class__.__name__
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
if sub_desc:
blocks_desc_parts.append(f" - {sub_desc}")
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
components = getattr(blocks, "expected_components", [])
@@ -1109,63 +1135,76 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
inputs = blocks.inputs
outputs = blocks.outputs
# Branch on whether workflows are defined
has_workflows = getattr(blocks, "_workflow_map", None) is not None
# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]
if has_workflows:
workflow_map = blocks._workflow_map
parts = []
if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
# use that as the shared output for all workflows
blocks_outputs = blocks.outputs
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
shared_outputs = (
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
)
if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
parts.append("## Workflow Input Specification\n")
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
# Per-workflow details: show trigger inputs with full param descriptions
for wf_name, trigger_inputs in workflow_map.items():
trigger_input_names = set(trigger_inputs.keys())
try:
workflow_blocks = blocks.get_workflow(wf_name)
except Exception:
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
parts.append("*Could not resolve workflow blocks.*\n")
parts.append("</details>\n")
continue
# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
wf_inputs = workflow_blocks.inputs
# Show only trigger inputs with full parameter descriptions
trigger_params = [p for p in wf_inputs if p.name in trigger_input_names]
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
inputs_str = format_params_markdown(trigger_params, header=None)
parts.append(inputs_str if inputs_str else "No additional inputs required.")
parts.append("")
parts.append("</details>\n")
# Common Inputs & Outputs section (like non-workflow pipelines)
all_inputs = blocks.inputs
all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs
inputs_str = format_params_markdown(all_inputs, "Inputs")
outputs_str = format_params_markdown(all_outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}")
io_specification_section = "\n".join(parts)
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
trigger_inputs_section = ""
else:
# Unified I/O section (original behavior)
inputs = blocks.inputs
outputs = blocks.outputs
inputs_str = format_params_markdown(inputs, "Inputs")
outputs_str = format_params_markdown(outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution
This pipeline contains blocks that are selected at runtime based on inputs:
@@ -1178,7 +1217,18 @@ This pipeline contains blocks that are selected at runtime based on inputs:
if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
if has_workflows:
# Derive tags from workflow names
workflow_names = set(blocks._workflow_map.keys())
if any("inpainting" in wf for wf in workflow_names):
tags.append("inpainting")
if any("image2image" in wf for wf in workflow_names):
tags.append("image-to-image")
if any("controlnet" in wf for wf in workflow_names):
tags.append("controlnet")
if any("text2image" in wf for wf in workflow_names):
tags.append("text-to-image")
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
@@ -1206,8 +1256,7 @@ This pipeline uses a {block_count}-block architecture that can be customized and
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"io_specification_section": io_specification_section,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}

View File

@@ -699,9 +699,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
if latents.ndim == 5:
# conditioning_mask needs to the same shape as latents in two stages generation.
batch_size, _, num_frames, height, width = latents.shape
mask_shape = (batch_size, 1, num_frames, height, width)
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
@@ -710,6 +714,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
else:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)

View File

@@ -107,6 +107,7 @@ def load_or_create_model_card(
widget: list[dict] | None = None,
inference: bool | None = None,
is_modular: bool = False,
update_model_card: bool = False,
) -> ModelCard:
"""
Loads or creates a model card.
@@ -133,6 +134,9 @@ def load_or_create_model_card(
`load_or_create_model_card` from a training script.
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
When True, uses model_description as-is without additional template formatting.
update_model_card: (`bool`, optional): When True, regenerates the model card content even if one
already exists on the remote repo. Existing card metadata (tags, license, etc.) is preserved. Only
supported for modular pipelines (i.e., `is_modular=True`).
"""
if not is_jinja_available():
raise ValueError(
@@ -141,9 +145,17 @@ def load_or_create_model_card(
" To install it, please run `pip install Jinja2`."
)
if update_model_card and not is_modular:
raise ValueError("`update_model_card=True` is only supported for modular pipelines (`is_modular=True`).")
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token)
# For modular pipelines, regenerate card content when requested (preserve existing metadata)
if update_model_card and is_modular and model_description is not None:
existing_data = model_card.data
model_card = ModelCard(model_description)
model_card.data = existing_data
except (EntryNotFoundError, RepositoryNotFoundError):
# Otherwise create a model card from template
if from_training:

View File

@@ -454,8 +454,7 @@ class TestModularModelCardContent:
"blocks_description",
"components_description",
"configs_section",
"inputs_description",
"outputs_description",
"io_specification_section",
"trigger_inputs_section",
"tags",
]
@@ -552,18 +551,19 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(inputs=inputs)
content = generate_modular_model_card_content(blocks)
assert "**Required:**" in content["inputs_description"]
assert "**Optional:**" in content["inputs_description"]
assert "prompt" in content["inputs_description"]
assert "num_steps" in content["inputs_description"]
assert "default: `50`" in content["inputs_description"]
io_section = content["io_specification_section"]
assert "**Inputs:**" in io_section
assert "prompt" in io_section
assert "num_steps" in io_section
assert "*optional*" in io_section
assert "defaults to `50`" in io_section
def test_inputs_description_empty(self):
"""Test handling of pipelines without specific inputs."""
blocks = self.create_mock_blocks(inputs=[])
content = generate_modular_model_card_content(blocks)
assert "No specific inputs defined" in content["inputs_description"]
assert "No specific inputs defined" in content["io_specification_section"]
def test_outputs_description_formatting(self):
"""Test that outputs are correctly formatted."""
@@ -573,15 +573,16 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(outputs=outputs)
content = generate_modular_model_card_content(blocks)
assert "images" in content["outputs_description"]
assert "Generated images" in content["outputs_description"]
io_section = content["io_specification_section"]
assert "images" in io_section
assert "Generated images" in io_section
def test_outputs_description_empty(self):
"""Test handling of pipelines without specific outputs."""
blocks = self.create_mock_blocks(outputs=[])
content = generate_modular_model_card_content(blocks)
assert "Standard pipeline outputs" in content["outputs_description"]
assert "Standard pipeline outputs" in content["io_specification_section"]
def test_trigger_inputs_section_with_triggers(self):
"""Test that trigger inputs section is generated when present."""

View File

@@ -24,7 +24,8 @@ from diffusers import (
LTX2ImageToVideoPipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
@@ -174,6 +175,15 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
return components
def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1):
upsampler = LTX2LatentUpsamplerModel(
in_channels=in_channels,
mid_channels=mid_channels,
num_blocks_per_stage=num_blocks_per_stage,
)
return upsampler
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
@@ -287,5 +297,60 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_two_stages_inference_with_upsampler(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "latent"
first_stage_output = pipe(**inputs)
video_latent = first_stage_output.frames
audio_latent = first_stage_output.audio
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1])
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler)
upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0]
self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32))
inputs["latents"] = upscaled_video_latent
inputs["audio_latents"] = audio_latent
inputs["output_type"] = "pt"
second_stage_output = pipe(**inputs)
video = second_stage_output.frames
audio = second_stage_output.audio
self.assertEqual(video.shape, (1, 5, 3, 64, 64))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079
]
)
expected_audio_slice = torch.tensor(
[
0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)