Compare commits

..

2 Commits

Author SHA1 Message Date
DN6
52bd90a3b3 update 2026-03-11 13:40:14 +05:30
DN6
b28e9204f7 update 2026-03-11 13:32:16 +05:30
11 changed files with 200 additions and 359 deletions

View File

@@ -532,6 +532,8 @@
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/controlnet_union
title: ControlNetUnion
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/ddim
title: DDIM
- local: api/pipelines/ddpm
@@ -675,8 +677,6 @@
title: CogVideoX
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/helios

View File

@@ -21,31 +21,29 @@
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Basic usage
## Loading original format checkpoints
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
```python
import torch
from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
model_id = "nvidia/Cosmos-Predict2.5-2B"
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
)
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
transformer = CosmosTransformer3DModel.from_single_file(
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor."
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
output = pipe(
image=None,
video=None,
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=93,
generator=torch.Generator().manual_seed(1),
).frames[0]
export_to_video(output, "text2world.mp4", fps=16)
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
).images[0]
output.save("output.png")
```
## Cosmos2_5_TransferPipeline

View File

@@ -44,7 +44,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
| [ControlNet-XS](controlnetxs) | text2image |
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
| [Cosmos](cosmos) | text2video, video2video |
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
| [DDIM](ddim) | unconditional image generation |
| [DDPM](ddpm) | unconditional image generation |

View File

@@ -2538,12 +2538,8 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha_tensor = state_dict.pop(alpha_key, None)
if alpha_tensor is None:
return 1.0, 1.0
scale = (
alpha_tensor.item() / rank
) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:

View File

@@ -60,16 +60,6 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -78,7 +68,6 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
@@ -135,7 +124,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -95,7 +95,6 @@ from .pag import (
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .prx import PRXPipeline
from .qwenimage import (
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
@@ -186,7 +185,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
("z-image-omni", ZImageOmniPipeline),
("ovis", OvisImagePipeline),
("prx", PRXPipeline),
]
)

View File

@@ -82,16 +82,13 @@ EXAMPLE_DOC_STRING = """
```python
>>> import cv2
>>> import numpy as np
>>> from PIL import Image
>>> import torch
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
>>> from diffusers.utils import export_to_video, load_video
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
>>> controlnet = AutoModel.from_pretrained(
... model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16
... )
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
... )

View File

@@ -60,7 +60,12 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.eval()
# Move inputs to device
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -84,59 +89,6 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -174,48 +126,3 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_custom_mesh_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,59 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import LTXVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class LTXTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return LTXVideoTransformer3DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 2
height = 16
width = 16
embedding_dim = 16
sequence_length = 16
def output_shape(self) -> tuple[int, int]:
return (512, 4)
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def input_shape(self) -> tuple[int, int]:
return (512, 4)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self):
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"num_frames": num_frames,
"height": height,
"width": width,
}
@property
def input_shape(self):
return (512, 4)
@property
def output_shape(self):
return (512, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -75,16 +62,57 @@ class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
"qk_norm": "rms_norm_across_heads",
"caption_channels": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_channels = 4
num_frames = 2
height = 16
width = 16
embedding_dim = 16
sequence_length = 16
return {
"hidden_states": randn_tensor(
(batch_size, num_frames * height * width, num_channels),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
"num_frames": num_frames,
"height": height,
"width": width,
}
class TestLTXTransformer(LTXTransformerTesterConfig, ModelTesterMixin):
"""Core model tests for LTX Video Transformer."""
class TestLTXTransformerMemory(LTXTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for LTX Video Transformer."""
class TestLTXTransformerTraining(LTXTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for LTX Video Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTXVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
super().test_gradient_checkpointing_is_applied(expected_set={"LTXVideoTransformer3DModel"})
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
class TestLTXTransformerCompile(LTXTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for LTX Video Transformer."""
def prepare_init_args_and_inputs_for_common(self):
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
# class TestLTXTransformerBitsAndBytes(LTXTransformerTesterConfig, BitsAndBytesTesterMixin):
# """BitsAndBytes quantization tests for LTX Video Transformer."""
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
# class TestLTXTransformerTorchAo(LTXTransformerTesterConfig, TorchAoTesterMixin):
# """TorchAo quantization tests for LTX Video Transformer."""

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,77 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
from diffusers import LTX2VideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class LTX2TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return LTX2VideoTransformer3DModel
@property
def dummy_input(self):
# Common
batch_size = 2
def output_shape(self) -> tuple[int, int]:
return (512, 4)
# Video
num_frames = 2
num_channels = 4
height = 16
width = 16
@property
def input_shape(self) -> tuple[int, int]:
return (512, 4)
# Audio
audio_num_frames = 9
audio_num_channels = 2
num_mel_bins = 2
@property
def main_input_name(self) -> str:
return "hidden_states"
# Text
embedding_dim = 16
sequence_length = 16
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
torch_device
)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self):
return {
"hidden_states": hidden_states,
"audio_hidden_states": audio_hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"audio_encoder_hidden_states": audio_encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"num_frames": num_frames,
"height": height,
"width": width,
"audio_num_frames": audio_num_frames,
"fps": 25.0,
}
@property
def input_shape(self):
return (512, 4)
@property
def output_shape(self):
return (512, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
@@ -101,122 +72,80 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
"caption_channels": 16,
"rope_double_precision": False,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_frames = 2
num_channels = 4
height = 16
width = 16
audio_num_frames = 9
audio_num_channels = 2
num_mel_bins = 2
embedding_dim = 16
sequence_length = 16
return {
"hidden_states": randn_tensor(
(batch_size, num_frames * height * width, num_channels),
generator=self.generator,
device=torch_device,
),
"audio_hidden_states": randn_tensor(
(batch_size, audio_num_frames, audio_num_channels * num_mel_bins),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"audio_encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": (randn_tensor((batch_size,), generator=self.generator, device=torch_device).abs() * 1000),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
"num_frames": num_frames,
"height": height,
"width": width,
"audio_num_frames": audio_num_frames,
"fps": 25.0,
}
class TestLTX2Transformer(LTX2TransformerTesterConfig, ModelTesterMixin):
"""Core model tests for LTX2 Video Transformer."""
class TestLTX2TransformerMemory(LTX2TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for LTX2 Video Transformer."""
class TestLTX2TransformerTraining(LTX2TransformerTesterConfig, TrainingTesterMixin):
"""Training tests for LTX2 Video Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTX2VideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
# torch.manual_seed(seed)
# init_dict, _ = self.prepare_init_args_and_inputs_for_common()
# # Calculate dummy inputs in a custom manner to ensure compatibility with original code
# batch_size = 2
# num_frames = 9
# latent_frames = 2
# text_embedding_dim = 16
# text_seq_len = 16
# fps = 25.0
# sampling_rate = 16000.0
# hop_length = 160.0
# sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
# timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
# num_channels = 4
# latent_height = 4
# latent_width = 4
# hidden_states = torch.randn(
# (batch_size, num_channels, latent_frames, latent_height, latent_width),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# # Patchify video latents (with patch_size (1, 1, 1))
# hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
# hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
# encoder_hidden_states = torch.randn(
# (batch_size, text_seq_len, text_embedding_dim),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# audio_num_channels = 2
# num_mel_bins = 2
# latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
# audio_hidden_states = torch.randn(
# (batch_size, audio_num_channels, latent_length, num_mel_bins),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# # Patchify audio latents
# audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
# audio_encoder_hidden_states = torch.randn(
# (batch_size, text_seq_len, text_embedding_dim),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# inputs_dict = {
# "hidden_states": hidden_states.to(device=torch_device),
# "audio_hidden_states": audio_hidden_states.to(device=torch_device),
# "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
# "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
# "timestep": timestep,
# "num_frames": latent_frames,
# "height": latent_height,
# "width": latent_width,
# "audio_num_frames": num_frames,
# "fps": 25.0,
# }
# model = self.model_class.from_pretrained(
# "diffusers-internal-dev/dummy-ltx2",
# subfolder="transformer",
# device_map="cpu",
# )
# # torch.manual_seed(seed)
# # model = self.model_class(**init_dict)
# model.to(torch_device)
# model.eval()
# with attention_backend("native"):
# with torch.no_grad():
# output = model(**inputs_dict)
# video_output, audio_output = output.to_tuple()
# self.assertIsNotNone(video_output)
# self.assertIsNotNone(audio_output)
# # input & output have to have the same shape
# video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
# self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
# audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
# self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
# # Check against expected slice
# # fmt: off
# video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
# audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
# # fmt: on
# video_output_flat = video_output.cpu().flatten().float()
# video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
# self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
# audio_output_flat = audio_output.cpu().flatten().float()
# audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
# self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
super().test_gradient_checkpointing_is_applied(expected_set={"LTX2VideoTransformer3DModel"})
class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
class TestLTX2TransformerAttention(LTX2TransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for LTX2 Video Transformer."""
def prepare_init_args_and_inputs_for_common(self):
return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()
@pytest.mark.skip(
"LTX2Attention does not set is_cross_attention, so fuse_projections tries to fuse Q+K+V together even for cross-attention modules with different input dimensions."
)
def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
pass
class TestLTX2TransformerCompile(LTX2TransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for LTX2 Video Transformer."""
# TODO: Add pretrained_model_name_or_path once a tiny LTX2 model is available on the Hub
# class TestLTX2TransformerBitsAndBytes(LTX2TransformerTesterConfig, BitsAndBytesTesterMixin):
# """BitsAndBytes quantization tests for LTX2 Video Transformer."""
# TODO: Add pretrained_model_name_or_path once a tiny LTX2 model is available on the Hub
# class TestLTX2TransformerTorchAo(LTX2TransformerTesterConfig, TorchAoTesterMixin):
# """TorchAo quantization tests for LTX2 Video Transformer."""