LTX 2: Improve encode_video by Accepting More Input Types (#13057)

* Support different pipeline outputs for LTX 2 encode_video

* Update examples to use improved encode_video function

* Fix comment

* Address review comments

* make style and make quality

* Have non-iterator video inputs respect video_chunks_number

* make style and make quality

* Add warning when encode_video receives a non-denormalized np.ndarray

* make style and make quality

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: dg845
Date: 2026-02-08 19:40:34 -08:00
Committed by: GitHub
Parent: 44f4dc0054
Commit: baaa8d040b
5 changed files with 67 additions and 19 deletions
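
For context, the point of the change is that a pipeline's `output_type="np"` result can now be handed to `encode_video` directly, without the manual uint8 conversion the examples used to do. Below is a minimal sketch of the new call pattern, assuming a pipeline object `pipe` is already loaded; the import path, prompt, and fps value are illustrative and not taken from this diff:

import torch

# Hypothetical import path -- adjust to wherever encode_video lives in your install.
from diffusers.pipelines.ltx2 import encode_video

# With output_type="np" the pipeline returns a float numpy array in [0, 1]
# of shape [batch, frames, height, width, channels].
video, audio = pipe(
    prompt="A red fox running through snow",  # illustrative prompt
    output_type="np",
    return_dict=False,
)

# No more manual `(video * 255).round().astype("uint8")` / `torch.from_numpy(video)`:
# encode_video now detects the [0, 1] float array and converts it itself.
encode_video(
    video[0],                 # first (and only) video in the batch
    fps=24,                   # illustrative value
    audio=None,               # audio handling omitted in this sketch
    audio_sample_rate=None,   # e.g. 24000 when audio is provided
    output_path="output.mp4",
    video_chunks_number=1,
)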


@@ -106,8 +106,6 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
@@ -185,8 +183,6 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],


@@ -13,12 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Iterator
 from fractions import Fraction
-from typing import Optional
+from itertools import chain
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL.Image
 import torch
+from tqdm import tqdm

-from ...utils import is_av_available
+from ...utils import get_logger, is_av_available
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name

 _CAN_USE_AV = is_av_available()
@@ -101,11 +109,59 @@ def _write_audio(
 def encode_video(
-    video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
+    video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
+    fps: int,
+    audio: Optional[torch.Tensor],
+    audio_sample_rate: Optional[int],
+    output_path: str,
+    video_chunks_number: int = 1,
 ) -> None:
-    video_np = video.cpu().numpy()
-    _, height, width, _ = video_np.shape
+    """
+    Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
+    https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
+
+    Args:
+        video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
+            A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
+            input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
+            usually return with `output_type="np"`).
+        fps (`int`):
+            The frames per second (FPS) of the encoded video.
+        audio (`torch.Tensor`, *optional*):
+            An audio waveform of shape [audio_channels, samples].
+        audio_sample_rate (`int`, *optional*):
+            The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
+        output_path (`str`):
+            The path to save the encoded video to.
+        video_chunks_number (`int`, *optional*, defaults to `1`):
+            The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
+            number of chunks to use often depends on the tiling config for the video VAE.
+    """
+    if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
+        # Pipeline output_type="pil"; assumes each image is in "RGB" mode
+        video_frames = [np.array(frame) for frame in video]
+        video = np.stack(video_frames, axis=0)
+        video = torch.from_numpy(video)
+    elif isinstance(video, np.ndarray):
+        # Pipeline output_type="np"
+        is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
+        if np.all(is_denormalized):
+            video = (video * 255).round().astype("uint8")
+        else:
+            logger.warning(
+                "Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "
+                "values in [0, ..., 255] and will be used as is."
+            )
+        video = torch.from_numpy(video)
+
+    if isinstance(video, torch.Tensor):
+        # Split into video_chunks_number along the frame dimension
+        video = torch.tensor_split(video, video_chunks_number, dim=0)
+        video = iter(video)
+
+    first_chunk = next(video)
+    _, height, width, _ = first_chunk.shape

     container = av.open(output_path, mode="w")
     stream = container.add_stream("libx264", rate=int(fps))
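
To make the new `np.ndarray` handling above concrete: an array whose values all lie in [0, 1] is treated as pipeline output and scaled to uint8, while anything else is assumed to already hold pixel values and is used as-is (with a warning logged). A standalone toy check mirroring that logic, using made-up arrays and plain NumPy:

import numpy as np

# Made-up stand-ins for the two cases encode_video distinguishes.
normalized = np.linspace(0.0, 1.0, 2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
pixels = (normalized * 255).round().astype("uint8")  # already pixel values in [0, 255]

for name, arr in (("normalized", normalized), ("pixels", pixels)):
    in_unit_range = bool(np.all((0 <= arr) & (arr <= 1)))
    if in_unit_range:
        out = (arr * 255).round().astype("uint8")  # same conversion as in the diff
    else:
        out = arr  # used as-is; encode_video logs a warning in this branch
    print(name, arr.dtype, "->", out.dtype, "(scaled)" if in_unit_range else "(as-is)")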
@@ -119,10 +175,12 @@ def encode_video(
         audio_stream = _prepare_audio_stream(container, audio_sample_rate)

-    for frame_array in video_np:
-        frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
-        for packet in stream.encode(frame):
-            container.mux(packet)
+    for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
+        video_chunk_cpu = video_chunk.to("cpu").numpy()
+        for frame_array in video_chunk_cpu:
+            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
+            for packet in stream.encode(frame):
+                container.mux(packet)

     # Flush encoder
     for packet in stream.encode():
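
The encoding loop above converts the video to NumPy one chunk at a time rather than all at once: a tensor input is split along the frame dimension with `torch.tensor_split`, the first chunk is peeked at to read the frame geometry, and `itertools.chain` puts it back in front before encoding. A small self-contained sketch of that mechanism with a dummy tensor (no PyAV involved):

from itertools import chain

import torch

video = torch.zeros(10, 64, 64, 3, dtype=torch.uint8)  # dummy 10-frame RGB video
video_chunks_number = 3

# tensor_split along dim 0 gives near-equal chunks: sizes 4, 3, 3 here.
chunks = iter(torch.tensor_split(video, video_chunks_number, dim=0))

# Peek at the first chunk for the frame size, as encode_video does.
first_chunk = next(chunks)
_, height, width, _ = first_chunk.shape
print(height, width)  # 64 64

# Chain the first chunk back in so every frame is still encoded exactly once.
total_frames = sum(chunk.shape[0] for chunk in chain([first_chunk], chunks))
print(total_frames)  # 10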


@@ -69,8 +69,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],


@@ -75,8 +75,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],


@@ -76,8 +76,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],