mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-10 12:55:19 +08:00
Compare commits
1 Commits
enable-cp-
...
automodel-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3fccd4489 |
@@ -106,6 +106,8 @@ video, audio = pipe(
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
@@ -183,6 +185,8 @@ video, audio = pipe(
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
|
||||
@@ -29,8 +29,31 @@ text_encoder = AutoModel.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
## Custom models
|
||||
|
||||
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
|
||||
|
||||
A custom model repository needs a Python module with the model class, and a `config.json` with an `auto_map` entry that maps `"AutoModel"` to `"module_file.ClassName"`.
|
||||
|
||||
```
|
||||
custom/custom-transformer-model/
|
||||
├── config.json
|
||||
├── my_model.py
|
||||
└── diffusion_pytorch_model.safetensors
|
||||
```
|
||||
|
||||
The `config.json` includes the `auto_map` field pointing to the custom class.
|
||||
|
||||
```json
|
||||
{
|
||||
"auto_map": {
|
||||
"AutoModel": "my_model.MyCustomModel"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Then load it with `trust_remote_code=True`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel
|
||||
@@ -40,7 +63,39 @@ transformer = AutoModel.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
For a real-world example, [Overworld/Waypoint-1-Small](https://huggingface.co/Overworld/Waypoint-1-Small/tree/main/transformer) hosts a custom `WorldModel` class across several modules in its `transformer` subfolder.
|
||||
|
||||
```
|
||||
transformer/
|
||||
├── config.json # auto_map: "model.WorldModel"
|
||||
├── model.py
|
||||
├── attn.py
|
||||
├── nn.py
|
||||
├── cache.py
|
||||
├── quantize.py
|
||||
├── __init__.py
|
||||
└── diffusion_pytorch_model.safetensors
|
||||
```
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel
|
||||
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
|
||||
|
||||
> [!WARNING]
|
||||
> As a precaution with `trust_remote_code=True`, pin the `revision` argument in [`AutoModel.from_pretrained`] to a specific commit hash so that later changes to the repository, which could introduce malicious code, are never downloaded and executed (unless you fully trust the model owners).
|
||||
>
|
||||
> ```py
|
||||
> transformer = AutoModel.from_pretrained(
|
||||
> "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, revision="a3d8cb2"
|
||||
> )
|
||||
> ```
|
||||
|
||||
> [!NOTE]
|
||||
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
|
||||
@@ -264,10 +264,6 @@ class _HubKernelConfig:
|
||||
function_attr: str
|
||||
revision: Optional[str] = None
|
||||
kernel_fn: Optional[Callable] = None
|
||||
wrapped_forward_attr: Optional[str] = None
|
||||
wrapped_backward_attr: Optional[str] = None
|
||||
wrapped_forward_fn: Optional[Callable] = None
|
||||
wrapped_backward_fn: Optional[Callable] = None
|
||||
|
||||
|
||||
# Registry for hub-based attention kernels
|
||||
@@ -282,11 +278,7 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# revision="fake-ops-return-probs",
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2",
|
||||
function_attr="flash_attn_func",
|
||||
revision=None,
|
||||
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
|
||||
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
||||
),
|
||||
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
|
||||
@@ -611,39 +603,22 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
|
||||
|
||||
# ===== Helpers for downloading kernels =====
|
||||
def _resolve_kernel_attr(module, attr_path: str):
|
||||
target = module
|
||||
for attr in attr_path.split("."):
|
||||
if not hasattr(target, attr):
|
||||
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
|
||||
target = getattr(target, attr)
|
||||
return target
|
||||
|
||||
|
||||
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
if backend not in _HUB_KERNELS_REGISTRY:
|
||||
return
|
||||
config = _HUB_KERNELS_REGISTRY[backend]
|
||||
|
||||
needs_kernel = config.kernel_fn is None
|
||||
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
|
||||
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
|
||||
|
||||
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
|
||||
if config.kernel_fn is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
from kernels import get_kernel
|
||||
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
||||
if needs_kernel:
|
||||
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
|
||||
kernel_func = getattr(kernel_module, config.function_attr)
|
||||
|
||||
if needs_wrapped_forward:
|
||||
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
|
||||
|
||||
if needs_wrapped_backward:
|
||||
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
|
||||
# Cache the downloaded kernel function in the config object
|
||||
config.kernel_fn = kernel_func
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
||||
@@ -1094,231 +1069,6 @@ def _flash_attention_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
|
||||
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||
wrapped_forward_fn = config.wrapped_forward_fn
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_forward_fn is None or wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
|
||||
"for context parallel execution."
|
||||
)
|
||||
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
||||
window_size = (-1, -1)
|
||||
softcap = 0.0
|
||||
alibi_slopes = None
|
||||
deterministic = False
|
||||
grad_enabled = any(x.requires_grad for x in (query, key, value))
|
||||
|
||||
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
|
||||
dropout_p = dropout_p if dropout_p > 0 else 1e-30
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
out, lse, S_dmask, rng_state = wrapped_forward_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
scale,
|
||||
is_causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
return_lse,
|
||||
)
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value, out, lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.alibi_slopes = alibi_slopes
|
||||
ctx.deterministic = deterministic
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
|
||||
)
|
||||
|
||||
query, key, value, out, lse, rng_state = ctx.saved_tensors
|
||||
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
|
||||
|
||||
_ = wrapped_backward_fn(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
grad_query,
|
||||
grad_key,
|
||||
grad_value,
|
||||
ctx.dropout_p,
|
||||
ctx.scale,
|
||||
ctx.is_causal,
|
||||
ctx.window_size[0],
|
||||
ctx.window_size[1],
|
||||
ctx.softcap,
|
||||
ctx.alibi_slopes,
|
||||
ctx.deterministic,
|
||||
rng_state,
|
||||
)
|
||||
|
||||
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_3_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
*,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: Optional[bool] = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
|
||||
if dropout_p != 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx._hub_kernel = func
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_3_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: Optional[bool] = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
):
|
||||
query, key, value = ctx.saved_tensors
|
||||
kernel_fn = ctx._hub_kernel
|
||||
with torch.enable_grad():
|
||||
query_r = query.detach().requires_grad_(True)
|
||||
key_r = key.detach().requires_grad_(True)
|
||||
value_r = value.detach().requires_grad_(True)
|
||||
|
||||
out = kernel_fn(
|
||||
q=query_r,
|
||||
k=key_r,
|
||||
v=value_r,
|
||||
softmax_scale=ctx.scale,
|
||||
causal=ctx.is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=False,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
|
||||
grad_query, grad_key, grad_value = torch.autograd.grad(
|
||||
out,
|
||||
(query_r, key_r, value_r),
|
||||
grad_out,
|
||||
retain_graph=False,
|
||||
allow_unused=False,
|
||||
)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _sage_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -1357,46 +1107,6 @@ def _sage_attention_forward_op(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
|
||||
if dropout_p > 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
@@ -2230,7 +1940,7 @@ def _flash_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -2248,35 +1958,17 @@ def _flash_attention_hub(
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_flash_attention_hub_forward_op,
|
||||
backward_op=_flash_attention_hub_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@@ -2423,7 +2115,7 @@ def _flash_attention_3(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -2438,68 +2130,33 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
return_attn_probs=return_attn_probs,
|
||||
)
|
||||
return (out[0], out[1]) if return_attn_probs else out
|
||||
|
||||
forward_op = functools.partial(
|
||||
_flash_attention_3_hub_forward_op,
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
return_attn_probs=return_attn_probs,
|
||||
)
|
||||
backward_op = functools.partial(
|
||||
_flash_attention_3_hub_backward_op,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
)
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_attn_probs,
|
||||
forward_op=forward_op,
|
||||
backward_op=backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_attn_probs:
|
||||
out, lse = out
|
||||
return out, lse
|
||||
|
||||
return out
|
||||
# When `return_attn_probs` is True, the above returns a tuple of
|
||||
# actual outputs and lse.
|
||||
return (out[0], out[1]) if return_attn_probs else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
@@ -3130,7 +2787,7 @@ def _sage_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -3158,23 +2815,6 @@ def _sage_attention_hub(
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_sage_attention_hub_forward_op,
|
||||
backward_op=_sage_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@@ -1598,11 +1598,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
self._blocks = blocks
|
||||
self.blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if modular_config_dict is not None:
|
||||
@@ -1649,9 +1649,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
for name, config_spec in self._config_specs.items():
|
||||
default_configs[name] = config_spec.default
|
||||
self.register_to_config(**default_configs)
|
||||
self.register_to_config(
|
||||
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
|
||||
)
|
||||
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> Dict[str, Any]:
|
||||
@@ -1660,7 +1658,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- Dictionary mapping input names to their default values
|
||||
"""
|
||||
params = {}
|
||||
for input_param in self._blocks.inputs:
|
||||
for input_param in self.blocks.inputs:
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
@@ -1831,15 +1829,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
Returns:
|
||||
- The docstring of the pipeline blocks
|
||||
"""
|
||||
return self._blocks.doc
|
||||
|
||||
@property
|
||||
def blocks(self) -> ModularPipelineBlocks:
|
||||
"""
|
||||
Returns:
|
||||
- A copy of the pipeline blocks
|
||||
"""
|
||||
return deepcopy(self._blocks)
|
||||
return self.blocks.doc
|
||||
|
||||
def register_components(self, **kwargs):
|
||||
"""
|
||||
@@ -2575,7 +2565,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
|
||||
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
|
||||
if hasattr(sub_block, "set_progress_bar_config"):
|
||||
sub_block.set_progress_bar_config(**kwargs)
|
||||
|
||||
@@ -2629,7 +2619,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
for expected_input_param in self._blocks.inputs:
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
@@ -2648,9 +2638,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
_, state = self._blocks(self, state)
|
||||
_, state = self.blocks(self, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
|
||||
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
|
||||
@@ -13,20 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from fractions import Fraction
|
||||
from itertools import chain
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...utils import get_logger, is_av_available
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
from ...utils import is_av_available
|
||||
|
||||
|
||||
_CAN_USE_AV = is_av_available()
|
||||
@@ -109,59 +101,11 @@ def _write_audio(
|
||||
|
||||
|
||||
def encode_video(
|
||||
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
|
||||
fps: int,
|
||||
audio: Optional[torch.Tensor],
|
||||
audio_sample_rate: Optional[int],
|
||||
output_path: str,
|
||||
video_chunks_number: int = 1,
|
||||
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
|
||||
) -> None:
|
||||
"""
|
||||
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
|
||||
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
|
||||
video_np = video.cpu().numpy()
|
||||
|
||||
Args:
|
||||
video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor` or `Iterator[torch.Tensor]`):
|
||||
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
|
||||
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
|
||||
usually return with `output_type="np"`).
|
||||
fps (`int`):
|
||||
The frames per second (FPS) of the encoded video.
|
||||
audio (`torch.Tensor`, *optional*):
|
||||
An audio waveform of shape [audio_channels, samples].
|
||||
audio_sample_rate (`int`, *optional*):
|
||||
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
|
||||
output_path (`str`):
|
||||
The path to save the encoded video to.
|
||||
video_chunks_number (`int`, *optional*, defaults to `1`):
|
||||
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
|
||||
number of chunks to use often depends on the tiling config for the video VAE.
|
||||
"""
|
||||
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
|
||||
# Pipeline output_type="pil"; assumes each image is in "RGB" mode
|
||||
video_frames = [np.array(frame) for frame in video]
|
||||
video = np.stack(video_frames, axis=0)
|
||||
video = torch.from_numpy(video)
|
||||
elif isinstance(video, np.ndarray):
|
||||
# Pipeline output_type="np"
|
||||
is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
|
||||
if np.all(is_denormalized):
|
||||
video = (video * 255).round().astype("uint8")
|
||||
else:
|
||||
logger.warning(
|
||||
"Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "
|
||||
"values in [0, ..., 255] and will be used as is."
|
||||
)
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
if isinstance(video, torch.Tensor):
|
||||
# Split into video_chunks_number along the frame dimension
|
||||
video = torch.tensor_split(video, video_chunks_number, dim=0)
|
||||
video = iter(video)
|
||||
|
||||
first_chunk = next(video)
|
||||
|
||||
_, height, width, _ = first_chunk.shape
|
||||
_, height, width, _ = video_np.shape
|
||||
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
@@ -175,12 +119,10 @@ def encode_video(
|
||||
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
|
||||
video_chunk_cpu = video_chunk.to("cpu").numpy()
|
||||
for frame_array in video_chunk_cpu:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
|
||||
@@ -69,6 +69,8 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
@@ -75,6 +75,8 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
@@ -76,6 +76,8 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )[0]
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
Reference in New Issue
Block a user