Compare commits

..

1 Commits

Author SHA1 Message Date
DN6
c3fccd4489 update 2026-02-08 14:03:36 +05:30
8 changed files with 119 additions and 482 deletions

View File

@@ -106,6 +106,8 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
@@ -183,6 +185,8 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],

View File

@@ -29,8 +29,31 @@ text_encoder = AutoModel.from_pretrained(
)
```
## Custom models
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
A custom model repository needs a Python module with the model class, and a `config.json` with an `auto_map` entry that maps `"AutoModel"` to `"module_file.ClassName"`.
```
custom/custom-transformer-model/
├── config.json
├── my_model.py
└── diffusion_pytorch_model.safetensors
```
The `config.json` includes the `auto_map` field pointing to the custom class.
```json
{
"auto_map": {
"AutoModel": "my_model.MyCustomModel"
}
}
```
Then load it with `trust_remote_code=True`.
```py
import torch
from diffusers import AutoModel
@@ -40,7 +63,39 @@ transformer = AutoModel.from_pretrained(
)
```
For a real-world example, [Overworld/Waypoint-1-Small](https://huggingface.co/Overworld/Waypoint-1-Small/tree/main/transformer) hosts a custom `WorldModel` class across several modules in its `transformer` subfolder.
```
transformer/
├── config.json # auto_map: "model.WorldModel"
├── model.py
├── attn.py
├── nn.py
├── cache.py
├── quantize.py
├── __init__.py
└── diffusion_pytorch_model.safetensors
```
```py
import torch
from diffusers import AutoModel
transformer = AutoModel.from_pretrained(
"Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
> [!WARNING]
> As a precaution with `trust_remote_code=True`, pass a commit hash to the `revision` argument in [`AutoModel.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust the model owners).
>
> ```py
> transformer = AutoModel.from_pretrained(
> "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, revision="a3d8cb2"
> )
> ```
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.

View File

@@ -264,10 +264,6 @@ class _HubKernelConfig:
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
wrapped_forward_attr: Optional[str] = None
wrapped_backward_attr: Optional[str] = None
wrapped_forward_fn: Optional[Callable] = None
wrapped_backward_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
@@ -282,11 +278,7 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -611,39 +603,22 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
if config.kernel_fn is not None:
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1094,231 +1069,6 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
*,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1357,46 +1107,6 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
@@ -2230,7 +1940,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -2248,35 +1958,17 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
return (out, lse) if return_lse else out
@@ -2423,7 +2115,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -2438,68 +2130,33 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
@_AttentionBackendRegistry.register(
@@ -3130,7 +2787,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -3158,23 +2815,6 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -1598,11 +1598,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
self._blocks = blocks
self.blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
# update component_specs and config_specs based on modular_model_index.json
if modular_config_dict is not None:
@@ -1649,9 +1649,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
@property
def default_call_parameters(self) -> Dict[str, Any]:
@@ -1660,7 +1658,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- Dictionary mapping input names to their default values
"""
params = {}
for input_param in self._blocks.inputs:
for input_param in self.blocks.inputs:
params[input_param.name] = input_param.default
return params
@@ -1831,15 +1829,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Returns:
- The docstring of the pipeline blocks
"""
return self._blocks.doc
@property
def blocks(self) -> ModularPipelineBlocks:
"""
Returns:
- A copy of the pipeline blocks
"""
return deepcopy(self._blocks)
return self.blocks.doc
def register_components(self, **kwargs):
"""
@@ -2575,7 +2565,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)
@@ -2629,7 +2619,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
for expected_input_param in self._blocks.inputs:
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
@@ -2648,9 +2638,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Run the pipeline
with torch.no_grad():
try:
_, state = self._blocks(self, state)
_, state = self.blocks(self, state)
except Exception:
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise

View File

@@ -13,20 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from fractions import Fraction
from itertools import chain
from typing import List, Optional, Union
from typing import Optional
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
from ...utils import get_logger, is_av_available
logger = get_logger(__name__) # pylint: disable=invalid-name
from ...utils import is_av_available
_CAN_USE_AV = is_av_available()
@@ -109,59 +101,11 @@ def _write_audio(
def encode_video(
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
fps: int,
audio: Optional[torch.Tensor],
audio_sample_rate: Optional[int],
output_path: str,
video_chunks_number: int = 1,
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
"""
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
video_np = video.cpu().numpy()
Args:
video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
usually return with `output_type="np"`).
fps (`int`)
The frames per second (FPS) of the encoded video.
audio (`torch.Tensor`, *optional*):
An audio waveform of shape [audio_channels, samples].
audio_sample_rate: (`int`, *optional*):
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
output_path (`str`):
The path to save the encoded video to.
video_chunks_number (`int`, *optional*, defaults to `1`):
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
number of chunks to use often depends on the tiling config for the video VAE.
"""
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
# Pipeline output_type="pil"; assumes each image is in "RGB" mode
video_frames = [np.array(frame) for frame in video]
video = np.stack(video_frames, axis=0)
video = torch.from_numpy(video)
elif isinstance(video, np.ndarray):
# Pipeline output_type="np"
is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
if np.all(is_denormalized):
video = (video * 255).round().astype("uint8")
else:
logger.warning(
"Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "
"values in [0, ..., 255] and will be used as is."
)
video = torch.from_numpy(video)
if isinstance(video, torch.Tensor):
# Split into video_chunks_number along the frame dimension
video = torch.tensor_split(video, video_chunks_number, dim=0)
video = iter(video)
first_chunk = next(video)
_, height, width, _ = first_chunk.shape
_, height, width, _ = video_np.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
@@ -175,12 +119,10 @@ def encode_video(
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
video_chunk_cpu = video_chunk.to("cpu").numpy()
for frame_array in video_chunk_cpu:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():

View File

@@ -69,6 +69,8 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -75,6 +75,8 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -76,6 +76,8 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],