Compare commits

...

5 Commits

Author SHA1 Message Date
Daniel Gu
7354055077 make style and make quality 2026-01-30 07:52:18 +01:00
Daniel Gu
cd60b3d151 Address review comments 2026-01-30 07:49:24 +01:00
Daniel Gu
857735f15d Fix comment 2026-01-30 02:26:36 +01:00
Daniel Gu
2e18d2c51a Update examples to use improved encode_video function 2026-01-30 02:10:51 +01:00
Daniel Gu
d5d2910654 Support different pipeline outputs for LTX 2 encode_video 2026-01-29 09:37:06 +01:00
5 changed files with 61 additions and 18 deletions

View File

@@ -106,8 +106,6 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
@@ -185,8 +183,6 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],

View File

@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator, Iterator
from fractions import Fraction
from typing import Optional
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
from ...utils import is_av_available
@@ -101,11 +105,52 @@ def _write_audio(
def encode_video(
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
fps: int,
audio: Optional[torch.Tensor],
audio_sample_rate: Optional[int],
output_path: str,
video_chunks_number: int = 1,
) -> None:
video_np = video.cpu().numpy()
"""
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
_, height, width, _ = video_np.shape
Args:
video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
usually return with `output_type="np"`).
fps (`int`)
The frames per second (FPS) of the encoded video.
audio (`torch.Tensor`, *optional*):
An audio waveform of shape [audio_channels, samples].
audio_sample_rate: (`int`, *optional*):
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
output_path (`str`):
The path to save the encoded video to.
video_chunks_number (`int`, *optional*, defaults to `1`):
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
number of chunks to use often depends on the tiling config for the video VAE.
"""
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
# Pipeline output_type="pil"
video_frames = [np.array(frame) for frame in video]
video = np.stack(video_frames, axis=0)
video = torch.from_numpy(video)
elif isinstance(video, np.ndarray):
# Pipeline output_type="np"
is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
if np.all(is_denormalized):
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
if isinstance(video, torch.Tensor):
video = iter([video])
first_chunk = next(video)
_, height, width, _ = first_chunk.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
@@ -119,10 +164,18 @@ def encode_video(
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
def all_tiles(
first_chunk: torch.Tensor, tiles_generator: Generator[Tuple[torch.Tensor, int], None, None]
) -> Generator[Tuple[torch.Tensor, int], None, None]:
yield first_chunk
yield from tiles_generator
for video_chunk in tqdm(all_tiles(first_chunk, video), total=video_chunks_number):
video_chunk_cpu = video_chunk.to("cpu").numpy()
for frame_array in video_chunk_cpu:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():

View File

@@ -69,8 +69,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -75,8 +75,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -76,8 +76,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],