mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-30 23:45:01 +08:00
Compare commits
5 Commits
main
...
ltx2-impro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7354055077 | ||
|
|
cd60b3d151 | ||
|
|
857735f15d | ||
|
|
2e18d2c51a | ||
|
|
d5d2910654 |
@@ -106,8 +106,6 @@ video, audio = pipe(
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
@@ -185,8 +183,6 @@ video, audio = pipe(
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
|
||||
@@ -13,10 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Generator, Iterator
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...utils import is_av_available
|
||||
|
||||
@@ -101,11 +105,52 @@ def _write_audio(
|
||||
|
||||
|
||||
def encode_video(
|
||||
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
|
||||
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
|
||||
fps: int,
|
||||
audio: Optional[torch.Tensor],
|
||||
audio_sample_rate: Optional[int],
|
||||
output_path: str,
|
||||
video_chunks_number: int = 1,
|
||||
) -> None:
|
||||
video_np = video.cpu().numpy()
|
||||
"""
|
||||
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
|
||||
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
|
||||
|
||||
_, height, width, _ = video_np.shape
|
||||
Args:
|
||||
video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor` or `Iterator[torch.Tensor]`):
|
||||
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
|
||||
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
|
||||
usually return with `output_type="np"`).
|
||||
fps (`int`):
|
||||
The frames per second (FPS) of the encoded video.
|
||||
audio (`torch.Tensor`, *optional*):
|
||||
An audio waveform of shape [audio_channels, samples].
|
||||
audio_sample_rate (`int`, *optional*):
|
||||
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
|
||||
output_path (`str`):
|
||||
The path to save the encoded video to.
|
||||
video_chunks_number (`int`, *optional*, defaults to `1`):
|
||||
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
|
||||
number of chunks to use often depends on the tiling config for the video VAE.
|
||||
"""
|
||||
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
|
||||
# Pipeline output_type="pil"
|
||||
video_frames = [np.array(frame) for frame in video]
|
||||
video = np.stack(video_frames, axis=0)
|
||||
video = torch.from_numpy(video)
|
||||
elif isinstance(video, np.ndarray):
|
||||
# Pipeline output_type="np"
|
||||
is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
|
||||
if np.all(is_denormalized):
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
if isinstance(video, torch.Tensor):
|
||||
video = iter([video])
|
||||
|
||||
first_chunk = next(video)
|
||||
|
||||
_, height, width, _ = first_chunk.shape
|
||||
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
@@ -119,10 +164,18 @@ def encode_video(
|
||||
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
def all_tiles(
|
||||
first_chunk: torch.Tensor, tiles_generator: Generator[Tuple[torch.Tensor, int], None, None]
|
||||
) -> Generator[Tuple[torch.Tensor, int], None, None]:
|
||||
yield first_chunk
|
||||
yield from tiles_generator
|
||||
|
||||
for video_chunk in tqdm(all_tiles(first_chunk, video), total=video_chunks_number):
|
||||
video_chunk_cpu = video_chunk.to("cpu").numpy()
|
||||
for frame_array in video_chunk_cpu:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
|
||||
@@ -69,8 +69,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
@@ -75,8 +75,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
@@ -76,8 +76,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )[0]
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
Reference in New Issue
Block a user