From baaa8d040bbbd6df8ec9b7835f3d4dce9abca5aa Mon Sep 17 00:00:00 2001
From: dg845 <58458699+dg845@users.noreply.github.com>
Date: Sun, 8 Feb 2026 19:40:34 -0800
Subject: [PATCH] LTX 2 Improve `encode_video` by Accepting More Input Types
 (#13057)

* Support different pipeline outputs for LTX 2 encode_video

* Update examples to use improved encode_video function

* Fix comment

* Address review comments

* make style and make quality

* Have non-iterator video inputs respect video_chunks_number

* make style and make quality

* Add warning when encode_video receives a non-denormalized np.ndarray

* make style and make quality

---------

Co-authored-by: Sayak Paul
---
 docs/source/en/api/pipelines/ltx2.md          |  4 -
 src/diffusers/pipelines/ltx2/export_utils.py  | 76 ++++++++++++++++---
 src/diffusers/pipelines/ltx2/pipeline_ltx2.py |  2 -
 .../ltx2/pipeline_ltx2_image2video.py         |  2 -
 .../ltx2/pipeline_ltx2_latent_upsample.py     |  2 -
 5 files changed, 67 insertions(+), 19 deletions(-)

diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md
index 24776b4230..c77efa09f5 100644
--- a/docs/source/en/api/pipelines/ltx2.md
+++ b/docs/source/en/api/pipelines/ltx2.md
@@ -106,8 +106,6 @@ video, audio = pipe(
     output_type="np",
     return_dict=False,
 )
-video = (video * 255).round().astype("uint8")
-video = torch.from_numpy(video)
 
 encode_video(
     video[0],
@@ -185,8 +183,6 @@ video, audio = pipe(
     output_type="np",
     return_dict=False,
 )
-video = (video * 255).round().astype("uint8")
-video = torch.from_numpy(video)
 
 encode_video(
     video[0],
diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py
index 0bc7a59db2..347601422c 100644
--- a/src/diffusers/pipelines/ltx2/export_utils.py
+++ b/src/diffusers/pipelines/ltx2/export_utils.py
@@ -13,12 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Iterator
 from fractions import Fraction
-from typing import Optional
+from itertools import chain
+from typing import List, Optional, Union
 
+import numpy as np
+import PIL.Image
 import torch
+from tqdm import tqdm
 
-from ...utils import is_av_available
+from ...utils import get_logger, is_av_available
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
 _CAN_USE_AV = is_av_available()
@@ -101,11 +109,59 @@ def _write_audio(
 
 
 def encode_video(
-    video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
+    video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
+    fps: int,
+    audio: Optional[torch.Tensor],
+    audio_sample_rate: Optional[int],
+    output_path: str,
+    video_chunks_number: int = 1,
 ) -> None:
-    video_np = video.cpu().numpy()
+    """
+    Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
+    https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
 
-    _, height, width, _ = video_np.shape
+    Args:
+        video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor` or `Iterator[torch.Tensor]`):
+            The video frames, of shape [frames, height, width, channels] with integer pixel values in [0, 255],
+            or an iterator yielding tensors of that shape (one chunk at a time). If the input is a `np.ndarray`,
+            it is expected to be a float array with values in [0, 1] (which is what pipelines usually return with
+            `output_type="np"`).
+        fps (`int`):
+            The frames per second (FPS) of the encoded video.
+        audio (`torch.Tensor`, *optional*):
+            An audio waveform of shape [audio_channels, samples].
+        audio_sample_rate (`int`, *optional*):
+            The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
+        output_path (`str`):
+            The path to save the encoded video to.
+        video_chunks_number (`int`, *optional*, defaults to `1`):
+            The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
+            number of chunks to use often depends on the tiling config for the video VAE.
+    """
+    if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
+        # Pipeline output_type="pil"; assumes each image is in "RGB" mode
+        video_frames = [np.array(frame) for frame in video]
+        video = np.stack(video_frames, axis=0)
+        video = torch.from_numpy(video)
+    elif isinstance(video, np.ndarray):
+        # Pipeline output_type="np"
+        is_denormalized = np.logical_and(0 <= video, video <= 1)
+        if np.all(is_denormalized):
+            video = (video * 255).round().astype("uint8")
+        else:
+            logger.warning(
+                "Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "
+                "values in [0, ..., 255] and will be used as is."
+            )
+        video = torch.from_numpy(video)
+
+    if isinstance(video, torch.Tensor):
+        # Split into video_chunks_number chunks along the frame dimension
+        video = torch.tensor_split(video, video_chunks_number, dim=0)
+        video = iter(video)
+
+    first_chunk = next(video)
+
+    _, height, width, _ = first_chunk.shape
 
     container = av.open(output_path, mode="w")
     stream = container.add_stream("libx264", rate=int(fps))
@@ -119,10 +175,12 @@
 
     audio_stream = _prepare_audio_stream(container, audio_sample_rate)
 
-    for frame_array in video_np:
-        frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
-        for packet in stream.encode(frame):
-            container.mux(packet)
+    for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
+        video_chunk_cpu = video_chunk.to("cpu").numpy()
+        for frame_array in video_chunk_cpu:
+            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
+            for packet in stream.encode(frame):
+                container.mux(packet)
 
     # Flush encoder
     for packet in stream.encode():
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
index a92a7a2c88..cb01159a81 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -69,8 +69,6 @@ EXAMPLE_DOC_STRING = """
         ...     output_type="np",
         ...     return_dict=False,
         ... )
-        >>> video = (video * 255).round().astype("uint8")
-        >>> video = torch.from_numpy(video)
 
         >>> encode_video(
         ...     video[0],
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
index 04d7ee89c5..c120e1f010 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
@@ -75,8 +75,6 @@ EXAMPLE_DOC_STRING = """
         ...     output_type="np",
         ...     return_dict=False,
         ... )
-        >>> video = (video * 255).round().astype("uint8")
-        >>> video = torch.from_numpy(video)
 
         >>> encode_video(
         ...     video[0],
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
index 340efd10f2..b0db1bdee3 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
@@ -76,8 +76,6 @@ EXAMPLE_DOC_STRING = """
         ...     output_type="np",
         ...     return_dict=False,
         ... )[0]
-        >>> video = (video * 255).round().astype("uint8")
-        >>> video = torch.from_numpy(video)
 
         >>> encode_video(
         ...     video[0],
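
Usage note: after this patch, a pipeline's `output_type="np"` result can be passed to `encode_video` directly. Below is a minimal sketch under stated assumptions: `LTX2Pipeline` being importable from the package root, a placeholder checkpoint id, and illustrative prompt/fps values; only `audio_sample_rate=24000` comes from the docstring above.

import torch

from diffusers import LTX2Pipeline
from diffusers.pipelines.ltx2.export_utils import encode_video

# Placeholder checkpoint id; substitute the LTX 2 checkpoint you actually use.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video, audio = pipe(
    prompt="A capybara lounging in a hot spring",  # illustrative prompt
    output_type="np",
    return_dict=False,
)

# video[0] is a float np.ndarray with values in [0, 1]; encode_video now handles
# the uint8 conversion and chunking itself, so the manual
# `(video * 255).round().astype("uint8")` / `torch.from_numpy(video)` steps from
# the old examples are no longer needed.
encode_video(
    video[0],
    fps=24,  # illustrative value
    audio=audio[0],  # assumes a batched audio tensor, as in the pipeline examples
    audio_sample_rate=24000,
    output_path="output.mp4",
)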