Compare commits

...

7 Commits

Author SHA1 Message Date
DN6
5c99566bab update 2026-03-01 12:46:45 +05:30
YiYi Xu
39188248a7 [modular] fallback to default_blocks_name when loading base block classes in ModularPipeline (#13193)
up

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
2026-02-27 18:58:01 -10:00
Sayak Paul
9b97932424 [tests] consistency tests for modular index (#13192)
* add a test to check modular index consistency

* check for compulsory keys.
2026-02-28 08:47:21 +05:30
YiYi Xu
680076fcc0 [Modular] update the auto pipeline blocks doc (#13148)
* update

* Apply suggestion from @yiyixuxu

* Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add to api

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
2026-02-27 10:50:35 -10:00
Christopher
5910a1cc6c Fixing Kohya loras loading: Flux.1-dev loras with TE ("lora_te1_" prefix) (#13188)
* fixing text encoder lora loading

* following Cursor's review
2026-02-27 15:43:41 +05:30
Jerry Song
40e96454f1 Fix LTX-2 image-to-video generation failure in two stages generation (#13187)
* Fix LTX-2 image-to-video generation failure in two stages generation

In LTX-2's two-stage image-to-video generation task, specifically after
the upsampling step, a shape mismatch occurs between the `latents` and
the `conditioning_mask`, which causes an error in function
`_create_noised_state`.

Fix it by creating the `conditioning_mask` based on the shape of the
`latents`.

* Add unit test for LTX-2 i2v two stages inference with upsampler

* Downscaling the upsampler in LTX-2 image-to-video unit test

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-27 00:55:01 -08:00
Varun Chawla
47455bd133 Fix Flash Attention 3 interface for new FA3 return format (#13173)
* Fix Flash Attention 3 interface compatibility for new FA3 versions

Newer versions of flash-attn (after Dao-AILab/flash-attention@ed20940)
no longer return lse by default from flash_attn_3_func. The function
now returns just the output tensor unless return_attn_probs=True is
passed.

Updated _wrapped_flash_attn_3 and _flash_varlen_attention_3 to pass
return_attn_probs and handle both old (always tuple) and new (tensor
or tuple) return formats gracefully.

Fixes #12022

* Simplify _wrapped_flash_attn_3 return unpacking

Since return_attn_probs=True is always passed, the result is
guaranteed to be a tuple. Remove the unnecessary isinstance guard.
2026-02-26 17:34:36 +05:30
10 changed files with 342 additions and 13 deletions

View File

@@ -14,4 +14,8 @@
## AutoPipelineBlocks
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
## ConditionalPipelineBlocks
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConditionalPipelineBlocks

View File

@@ -121,7 +121,7 @@ from diffusers.modular_pipelines import AutoPipelineBlocks
class AutoImageBlocks(AutoPipelineBlocks):
# List of sub-block classes to choose from
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
# Names for each block in the same order
block_names = ["inpaint", "img2img", "text2img"]
# Trigger inputs that determine which block to run
@@ -129,8 +129,8 @@ class AutoImageBlocks(AutoPipelineBlocks):
# - "image" triggers img2img workflow (but only if mask is not provided)
# - if none of above, runs the text2img workflow (default)
block_trigger_inputs = ["mask", "image", None]
# Description is extremely important for AutoPipelineBlocks
@property
def description(self):
return (
"Pipeline generates images given different types of conditions!\n"
@@ -141,7 +141,7 @@ class AutoImageBlocks(AutoPipelineBlocks):
)
```
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, it's conditional logic may be difficult to figure out if it isn't properly explained.
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, its conditional logic may be difficult to figure out if it isn't properly explained.
Create an instance of `AutoImageBlocks`.
@@ -152,5 +152,74 @@ auto_blocks = AutoImageBlocks()
For more complex compositions, such as nested [`~modular_pipelines.AutoPipelineBlocks`] blocks when they're used as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the block that is actually run based on your input.
```py
auto_blocks.get_execution_blocks("mask")
auto_blocks.get_execution_blocks(mask=True)
```
## ConditionalPipelineBlocks
[`~modular_pipelines.AutoPipelineBlocks`] is a special case of [`~modular_pipelines.ConditionalPipelineBlocks`]. While [`~modular_pipelines.AutoPipelineBlocks`] selects blocks based on whether a trigger input is provided or not, [`~modular_pipelines.ConditionalPipelineBlocks`] is able to select a block based on custom selection logic provided in the `select_block` method.
Here is the same example written using [`~modular_pipelines.ConditionalPipelineBlocks`] directly:
```py
from diffusers.modular_pipelines import ConditionalPipelineBlocks
class AutoImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return (
"Pipeline generates images given different types of conditions!\n"
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
+ " - inpaint workflow is run when `mask` is provided.\n"
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
)
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name ("text2img")
```
The inputs listed in `block_trigger_inputs` are passed as keyword arguments to `select_block()`. When `select_block` returns `None`, it falls back to `default_block_name`. If `default_block_name` is also `None`, the entire conditional block is skipped — this is useful for optional processing steps that should only run when specific inputs are provided.
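For instance, an optional post-processing step can be modeled with `default_block_name = None`. A minimal sketch, where `UpscaleBlock` is a hypothetical block class used purely for illustration:
```py
from diffusers.modular_pipelines import ConditionalPipelineBlocks

class OptionalUpscaleBlocks(ConditionalPipelineBlocks):
    # `UpscaleBlock` is a hypothetical sub-block, for illustration only
    block_classes = [UpscaleBlock]
    block_names = ["upscale"]
    block_trigger_inputs = ["upscale_factor"]
    # no fallback: when select_block returns None, this step is skipped entirely
    default_block_name = None

    @property
    def description(self):
        return "Optional upscaling step that only runs when `upscale_factor` is provided."

    def select_block(self, upscale_factor=None) -> str | None:
        if upscale_factor is not None:
            return "upscale"
        return None  # no default_block_name, so the whole block is skipped
```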
## Workflows
Pipelines that contain conditional blocks ([`~modular_pipelines.AutoPipelineBlocks`] or [`~modular_pipelines.ConditionalPipelineBlocks`]) can support multiple workflows. For example, our SDXL modular pipeline supports a dozen workflows in a single pipeline. But this also means it can be confusing for users to know which workflows are supported and how to run them. For pipeline builders, it's useful to be able to extract only the blocks relevant to a specific workflow.
We recommend defining a `_workflow_map` to give each workflow a name and explicitly list the inputs it requires.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
class MyPipelineBlocks(SequentialPipelineBlocks):
block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock]
block_names = ["text_encoder", "auto_image", "decode"]
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"image": True, "prompt": True},
"inpaint": {"mask": True, "image": True, "prompt": True},
}
```
All of our built-in modular pipelines come with pre-defined workflows. The `available_workflows` property lists all supported workflows:
```py
pipeline_blocks = MyPipelineBlocks()
pipeline_blocks.available_workflows
# ['text2image', 'image2image', 'inpaint']
```
Retrieve a specific workflow with `get_workflow` to inspect and debug the blocks that execute it.
```py
pipeline_blocks.get_workflow("inpaint")
```
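The returned workflow is a plain [`~modular_pipelines.SequentialPipelineBlocks`] that contains only the blocks the workflow actually executes, so you can inspect it or even turn it into a standalone pipeline. A minimal sketch, reusing `MyPipelineBlocks` from above (the repo id is a placeholder):
```py
inpaint_blocks = pipeline_blocks.get_workflow("inpaint")

# only the blocks the inpaint workflow runs are included
print(inpaint_blocks.sub_blocks)

# the extracted workflow can be initialized as its own pipeline
# ("your-org/your-modular-repo" is a placeholder repo id)
inpaint_pipe = inpaint_blocks.init_pipeline("your-org/your-modular-repo")
```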

View File

@@ -648,6 +648,28 @@ class ConfigMixin:
)
return config_file
@classmethod
def _get_dataclass_from_config(cls, config_dict: dict[str, Any]):
sig = inspect.signature(cls.__init__)
fields = []
for name, param in sig.parameters.items():
if name == "self" or name == "kwargs" or name in cls.ignore_for_config:
continue
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
if param.default is not inspect.Parameter.empty:
fields.append((name, annotation, dataclasses.field(default=param.default)))
else:
fields.append((name, annotation))
dc_cls = dataclasses.make_dataclass(
f"{cls.__name__}Config",
fields,
frozen=True,
)
valid_fields = {f.name for f in dataclasses.fields(dc_cls)}
init_kwargs = {k: v for k, v in config_dict.items() if k in valid_fields}
return dc_cls(**init_kwargs)
def register_to_config(init):
r"""

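For context, a minimal usage sketch of the `_get_dataclass_from_config` helper added above. It is a private classmethod, so this simply mirrors how the new tests further down exercise it, using `DDIMScheduler` as an example:
```py
import dataclasses

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
# build a frozen dataclass snapshot of the scheduler's registered config
config_dc = DDIMScheduler._get_dataclass_from_config(dict(scheduler.config))

assert dataclasses.is_dataclass(config_dc)
assert type(config_dc).__name__ == "DDIMSchedulerConfig"
print(config_dc.num_train_timesteps)
```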
View File

@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_"))
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
}
if any("text_projection" in k for k in state_dict):

View File

@@ -733,7 +733,7 @@ def _wrapped_flash_attn_3(
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
out, lse, *_ = flash_attn_3_func(
result = flash_attn_3_func(
q=q,
k=k,
v=v,
@@ -750,7 +750,9 @@ def _wrapped_flash_attn_3(
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=True,
)
out, lse, *_ = result
lse = lse.permute(0, 2, 1)
return out, lse
@@ -2701,7 +2703,7 @@ def _flash_varlen_attention_3(
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out, lse, *_ = flash_attn_3_varlen_func(
result = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2711,7 +2713,13 @@ def _flash_varlen_attention_3(
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out

View File

@@ -1633,7 +1633,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
blocks_class = getattr(diffusers_module, blocks_class_name, None)
# If the blocks_class is not found or is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes
# fall back to default_blocks_name
if blocks_class is None or not blocks_class.block_classes:
blocks_class_name = self.default_blocks_name
blocks_class = getattr(diffusers_module, blocks_class_name)
if blocks_class is not None:
blocks = blocks_class()
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")

View File

@@ -699,9 +699,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
if latents.ndim == 5:
# conditioning_mask needs to be the same shape as latents in two-stage generation.
batch_size, _, num_frames, height, width = latents.shape
mask_shape = (batch_size, 1, num_frames, height, width)
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
@@ -710,6 +714,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
else:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)

View File

@@ -1,4 +1,6 @@
import gc
import json
import os
import tempfile
from typing import Callable
@@ -349,6 +351,33 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_modular_index_consistency(self):
pipe = self.get_pipeline()
components_spec = pipe._component_specs
components = sorted(components_spec.keys())
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
index_file = os.path.join(tmpdir, "modular_model_index.json")
assert os.path.exists(index_file)
with open(index_file) as f:
index_contents = json.load(f)
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
for k in compulsory_keys:
assert k in index_contents
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
for component in components:
    spec = components_spec[component]
    if getattr(spec, "pretrained_model_name_or_path", None) is not None:
        assert component in index_contents, f"{component} should be present in index but isn't."
        for attr in to_check_attrs:
            attr_value_from_index = index_contents[component][2][attr]
            assert getattr(spec, attr) == attr_value_from_index
def test_workflow_map(self):
blocks = self.pipeline_blocks_class()
if blocks._workflow_map is None:
@@ -699,3 +728,27 @@ class TestLoadComponentsSkipBehavior:
# Verify test_component was not loaded
assert not hasattr(pipe, "test_component") or pipe.test_component is None
class TestModularPipelineInitFallback:
"""Test that ModularPipeline.__init__ falls back to default_blocks_name when
_blocks_class_name is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict)."""
def test_init_fallback_when_blocks_class_name_is_base_class(self, tmp_path):
# 1. Load pipeline and get a workflow (returns a base SequentialPipelineBlocks)
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
t2i_blocks = pipe.blocks.get_workflow("text2image")
assert t2i_blocks.__class__.__name__ == "SequentialPipelineBlocks"
# 2. Use init_pipeline to create a new pipeline from the workflow blocks
t2i_pipe = t2i_blocks.init_pipeline("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
# 3. Save and reload — the saved config will have _blocks_class_name="SequentialPipelineBlocks"
save_dir = str(tmp_path / "pipeline")
t2i_pipe.save_pretrained(save_dir)
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
# 4. Verify it fell back to default_blocks_name and has correct blocks
assert loaded_pipe.__class__.__name__ == pipe.__class__.__name__
assert loaded_pipe._blocks.__class__.__name__ == pipe._blocks.__class__.__name__
assert len(loaded_pipe._blocks.sub_blocks) == len(pipe._blocks.sub_blocks)

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import tempfile
import unittest
@@ -305,3 +306,96 @@ class ConfigTester(unittest.TestCase):
result = json.loads(json_string)
assert result["test_file_1"] == config.config.test_file_1.as_posix()
assert result["test_file_2"] == config.config.test_file_2.as_posix()
class SampleObjectTyped(ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a: int = 2,
b: int = 5,
c: str = "hello",
):
pass
class SampleObjectWithIgnore(ConfigMixin):
config_name = "config.json"
ignore_for_config = ["secret"]
@register_to_config
def __init__(
self,
a: int = 2,
secret: str = "hidden",
):
pass
class DataclassFromConfigTester(unittest.TestCase):
def test_get_dataclass_from_config_returns_frozen_dataclass(self):
obj = SampleObject()
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
assert dataclasses.is_dataclass(tc)
with self.assertRaises(dataclasses.FrozenInstanceError):
tc.a = 99
def test_get_dataclass_from_config_class_name(self):
obj = SampleObject()
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
assert type(tc).__name__ == "SampleObjectConfig"
def test_get_dataclass_from_config_values_match_config(self):
obj = SampleObject(a=10, b=20)
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
assert tc.a == 10
assert tc.b == 20
assert tc.c == (2, 5)
assert tc.d == "for diffusion"
assert tc.e == [1, 3]
def test_get_dataclass_from_config_from_raw_dict(self):
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
assert tc.a == 7
assert tc.b == 3
assert tc.c == "world"
def test_get_dataclass_from_config_annotations(self):
tc = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "hi"})
fields = {f.name: f.type for f in dataclasses.fields(tc)}
assert fields["a"] is int
assert fields["b"] is int
assert fields["c"] is str
def test_get_dataclass_from_config_asdict_roundtrip(self):
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
d = dataclasses.asdict(tc)
assert d == {"a": 7, "b": 3, "c": "world"}
def test_get_dataclass_from_config_ignores_extra_keys(self):
tc = SampleObjectTyped._get_dataclass_from_config(
{"a": 1, "b": 2, "c": "hi", "_class_name": "Foo", "extra": 99}
)
assert tc.a == 1
assert not hasattr(tc, "_class_name")
assert not hasattr(tc, "extra")
def test_get_dataclass_from_config_respects_ignore_for_config(self):
tc = SampleObjectWithIgnore._get_dataclass_from_config({"a": 5})
assert not hasattr(tc, "secret")
assert tc.a == 5
def test_get_dataclass_from_config_works_for_scheduler(self):
scheduler = DDIMScheduler()
tc = DDIMScheduler._get_dataclass_from_config(dict(scheduler.config))
assert dataclasses.is_dataclass(tc)
assert type(tc).__name__ == "DDIMSchedulerConfig"
assert tc.num_train_timesteps == scheduler.config.num_train_timesteps
def test_get_dataclass_from_config_different_values(self):
tc1 = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "x"})
tc2 = SampleObjectTyped._get_dataclass_from_config({"a": 9, "b": 8, "c": "y"})
assert tc1.a == 1
assert tc2.a == 9

View File

@@ -24,7 +24,8 @@ from diffusers import (
LTX2ImageToVideoPipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
@@ -174,6 +175,15 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
return components
def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1):
upsampler = LTX2LatentUpsamplerModel(
in_channels=in_channels,
mid_channels=mid_channels,
num_blocks_per_stage=num_blocks_per_stage,
)
return upsampler
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
@@ -287,5 +297,60 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_two_stages_inference_with_upsampler(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "latent"
first_stage_output = pipe(**inputs)
video_latent = first_stage_output.frames
audio_latent = first_stage_output.audio
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1])
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler)
upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0]
self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32))
inputs["latents"] = upscaled_video_latent
inputs["audio_latents"] = audio_latent
inputs["output_type"] = "pt"
second_stage_output = pipe(**inputs)
video = second_stage_output.frames
audio = second_stage_output.audio
self.assertEqual(video.shape, (1, 5, 3, 64, 64))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079
]
)
expected_audio_slice = torch.tensor(
[
0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)