mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-05 16:20:39 +08:00
Compare commits
10 Commits
custom-dev
...
dynamic-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1df740aac | ||
|
|
8d20369792 | ||
|
|
5910a1cc6c | ||
|
|
40e96454f1 | ||
|
|
47455bd133 | ||
|
|
97c2c6e397 | ||
|
|
212db7b999 | ||
|
|
31058485f1 | ||
|
|
aac94befce | ||
|
|
1f6ac1c3d1 |
@@ -46,6 +46,20 @@ output = pipe(
|
|||||||
output.save("output.png")
|
output.save("output.png")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Cosmos2_5_TransferPipeline
|
||||||
|
|
||||||
|
[[autodoc]] Cosmos2_5_TransferPipeline
|
||||||
|
- all
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
|
||||||
|
## Cosmos2_5_PredictBasePipeline
|
||||||
|
|
||||||
|
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
||||||
|
- all
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
|
||||||
## CosmosTextToWorldPipeline
|
## CosmosTextToWorldPipeline
|
||||||
|
|
||||||
[[autodoc]] CosmosTextToWorldPipeline
|
[[autodoc]] CosmosTextToWorldPipeline
|
||||||
@@ -70,12 +84,6 @@ output.save("output.png")
|
|||||||
- all
|
- all
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
## Cosmos2_5_PredictBasePipeline
|
|
||||||
|
|
||||||
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
|
||||||
- all
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
## CosmosPipelineOutput
|
## CosmosPipelineOutput
|
||||||
|
|
||||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ if __name__ == "__main__":
|
|||||||
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
|
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
torchrun run_distributed.py --nproc_per_node=2
|
torchrun --nproc_per_node=2 run_distributed.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## device_map
|
## device_map
|
||||||
|
|||||||
@@ -94,9 +94,15 @@ python scripts/convert_cosmos_to_diffusers.py \
|
|||||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||||
--transformer_ckpt_path $transformer_ckpt_path \
|
--transformer_ckpt_path $transformer_ckpt_path \
|
||||||
--vae_type wan2.1 \
|
--vae_type wan2.1 \
|
||||||
--output_path converted/transfer/2b/general/depth \
|
--output_path converted/transfer/2b/general/depth/pipeline \
|
||||||
--save_pipeline
|
--save_pipeline
|
||||||
|
|
||||||
|
python scripts/convert_cosmos_to_diffusers.py \
|
||||||
|
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||||
|
--transformer_ckpt_path $transformer_ckpt_path \
|
||||||
|
--vae_type wan2.1 \
|
||||||
|
--output_path converted/transfer/2b/general/depth/models
|
||||||
|
|
||||||
# edge
|
# edge
|
||||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
|
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
|
||||||
|
|
||||||
@@ -120,9 +126,15 @@ python scripts/convert_cosmos_to_diffusers.py \
|
|||||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||||
--transformer_ckpt_path $transformer_ckpt_path \
|
--transformer_ckpt_path $transformer_ckpt_path \
|
||||||
--vae_type wan2.1 \
|
--vae_type wan2.1 \
|
||||||
--output_path converted/transfer/2b/general/blur \
|
--output_path converted/transfer/2b/general/blur/pipeline \
|
||||||
--save_pipeline
|
--save_pipeline
|
||||||
|
|
||||||
|
python scripts/convert_cosmos_to_diffusers.py \
|
||||||
|
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||||
|
--transformer_ckpt_path $transformer_ckpt_path \
|
||||||
|
--vae_type wan2.1 \
|
||||||
|
--output_path converted/transfer/2b/general/blur/models
|
||||||
|
|
||||||
# seg
|
# seg
|
||||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
|
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
|
||||||
|
|
||||||
@@ -130,8 +142,14 @@ python scripts/convert_cosmos_to_diffusers.py \
|
|||||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||||
--transformer_ckpt_path $transformer_ckpt_path \
|
--transformer_ckpt_path $transformer_ckpt_path \
|
||||||
--vae_type wan2.1 \
|
--vae_type wan2.1 \
|
||||||
--output_path converted/transfer/2b/general/seg \
|
--output_path converted/transfer/2b/general/seg/pipeline \
|
||||||
--save_pipeline
|
--save_pipeline
|
||||||
|
|
||||||
|
python scripts/convert_cosmos_to_diffusers.py \
|
||||||
|
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||||
|
--transformer_ckpt_path $transformer_ckpt_path \
|
||||||
|
--vae_type wan2.1 \
|
||||||
|
--output_path converted/transfer/2b/general/seg/models
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
|||||||
)
|
)
|
||||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
||||||
|
|
||||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
|
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
|
||||||
if has_diffb:
|
if has_diffb:
|
||||||
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
||||||
if zero_status_diff_b:
|
if zero_status_diff_b:
|
||||||
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
|||||||
state_dict = {
|
state_dict = {
|
||||||
_custom_replace(k, limit_substrings): v
|
_custom_replace(k, limit_substrings): v
|
||||||
for k, v in state_dict.items()
|
for k, v in state_dict.items()
|
||||||
if k.startswith(("lora_unet_", "lora_te_"))
|
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if any("text_projection" in k for k in state_dict):
|
if any("text_projection" in k for k in state_dict):
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
|
|||||||
_REQUIRED_XLA_VERSION = "2.2"
|
_REQUIRED_XLA_VERSION = "2.2"
|
||||||
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||||
|
|
||||||
|
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||||
@@ -73,8 +75,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
|
|||||||
|
|
||||||
|
|
||||||
if _CAN_USE_FLASH_ATTN:
|
if _CAN_USE_FLASH_ATTN:
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
try:
|
||||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
|
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||||
|
except (ImportError, OSError, RuntimeError) as e:
|
||||||
|
# Handle ABI mismatch or other import failures gracefully.
|
||||||
|
# This can happen when flash_attn was compiled against a different PyTorch version.
|
||||||
|
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
|
||||||
|
_CAN_USE_FLASH_ATTN = False
|
||||||
|
flash_attn_func = None
|
||||||
|
flash_attn_varlen_func = None
|
||||||
|
_wrapped_flash_attn_backward = None
|
||||||
|
_wrapped_flash_attn_forward = None
|
||||||
else:
|
else:
|
||||||
flash_attn_func = None
|
flash_attn_func = None
|
||||||
flash_attn_varlen_func = None
|
flash_attn_varlen_func = None
|
||||||
@@ -83,26 +95,47 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if _CAN_USE_FLASH_ATTN_3:
|
if _CAN_USE_FLASH_ATTN_3:
|
||||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
try:
|
||||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||||
|
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||||
|
except (ImportError, OSError, RuntimeError) as e:
|
||||||
|
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
|
||||||
|
_CAN_USE_FLASH_ATTN_3 = False
|
||||||
|
flash_attn_3_func = None
|
||||||
|
flash_attn_3_varlen_func = None
|
||||||
else:
|
else:
|
||||||
flash_attn_3_func = None
|
flash_attn_3_func = None
|
||||||
flash_attn_3_varlen_func = None
|
flash_attn_3_varlen_func = None
|
||||||
|
|
||||||
if _CAN_USE_AITER_ATTN:
|
if _CAN_USE_AITER_ATTN:
|
||||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
try:
|
||||||
|
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||||
|
except (ImportError, OSError, RuntimeError) as e:
|
||||||
|
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
|
||||||
|
_CAN_USE_AITER_ATTN = False
|
||||||
|
aiter_flash_attn_func = None
|
||||||
else:
|
else:
|
||||||
aiter_flash_attn_func = None
|
aiter_flash_attn_func = None
|
||||||
|
|
||||||
if _CAN_USE_SAGE_ATTN:
|
if _CAN_USE_SAGE_ATTN:
|
||||||
from sageattention import (
|
try:
|
||||||
sageattn,
|
from sageattention import (
|
||||||
sageattn_qk_int8_pv_fp8_cuda,
|
sageattn,
|
||||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
sageattn_qk_int8_pv_fp8_cuda,
|
||||||
sageattn_qk_int8_pv_fp16_cuda,
|
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||||
sageattn_qk_int8_pv_fp16_triton,
|
sageattn_qk_int8_pv_fp16_cuda,
|
||||||
sageattn_varlen,
|
sageattn_qk_int8_pv_fp16_triton,
|
||||||
)
|
sageattn_varlen,
|
||||||
|
)
|
||||||
|
except (ImportError, OSError, RuntimeError) as e:
|
||||||
|
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
|
||||||
|
_CAN_USE_SAGE_ATTN = False
|
||||||
|
sageattn = None
|
||||||
|
sageattn_qk_int8_pv_fp8_cuda = None
|
||||||
|
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
||||||
|
sageattn_qk_int8_pv_fp16_cuda = None
|
||||||
|
sageattn_qk_int8_pv_fp16_triton = None
|
||||||
|
sageattn_varlen = None
|
||||||
else:
|
else:
|
||||||
sageattn = None
|
sageattn = None
|
||||||
sageattn_qk_int8_pv_fp16_cuda = None
|
sageattn_qk_int8_pv_fp16_cuda = None
|
||||||
@@ -113,26 +146,48 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if _CAN_USE_FLEX_ATTN:
|
if _CAN_USE_FLEX_ATTN:
|
||||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
try:
|
||||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||||
# compiled function.
|
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||||
import torch.nn.attention.flex_attention as flex_attention
|
# compiled function.
|
||||||
|
import torch.nn.attention.flex_attention as flex_attention
|
||||||
|
except (ImportError, OSError, RuntimeError) as e:
|
||||||
|
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
|
||||||
|
_CAN_USE_FLEX_ATTN = False
|
||||||
|
flex_attention = None
|
||||||
|
else:
|
||||||
|
flex_attention = None
|
||||||
|
|
||||||
|
|
||||||
if _CAN_USE_NPU_ATTN:
|
if _CAN_USE_NPU_ATTN:
|
||||||
from torch_npu import npu_fusion_attention
|
try:
|
||||||
|
from torch_npu import npu_fusion_attention
|
||||||
|
except (ImportError, OSError, RuntimeError) as e:
|
||||||
|
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
|
||||||
|
_CAN_USE_NPU_ATTN = False
|
||||||
|
npu_fusion_attention = None
|
||||||
else:
|
else:
|
||||||
npu_fusion_attention = None
|
npu_fusion_attention = None
|
||||||
|
|
||||||
|
|
||||||
if _CAN_USE_XLA_ATTN:
|
if _CAN_USE_XLA_ATTN:
|
||||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
try:
|
||||||
|
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||||
|
except (ImportError, OSError, RuntimeError) as e:
|
||||||
|
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
|
||||||
|
_CAN_USE_XLA_ATTN = False
|
||||||
|
xla_flash_attention = None
|
||||||
else:
|
else:
|
||||||
xla_flash_attention = None
|
xla_flash_attention = None
|
||||||
|
|
||||||
|
|
||||||
if _CAN_USE_XFORMERS_ATTN:
|
if _CAN_USE_XFORMERS_ATTN:
|
||||||
import xformers.ops as xops
|
try:
|
||||||
|
import xformers.ops as xops
|
||||||
|
except (ImportError, OSError, RuntimeError) as e:
|
||||||
|
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
|
||||||
|
_CAN_USE_XFORMERS_ATTN = False
|
||||||
|
xops = None
|
||||||
else:
|
else:
|
||||||
xops = None
|
xops = None
|
||||||
|
|
||||||
@@ -158,8 +213,6 @@ else:
|
|||||||
_register_fake = register_fake_no_op
|
_register_fake = register_fake_no_op
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
# TODO(aryan): Add support for the following:
|
# TODO(aryan): Add support for the following:
|
||||||
# - Sage Attention++
|
# - Sage Attention++
|
||||||
# - block sparse, radial and other attention methods
|
# - block sparse, radial and other attention methods
|
||||||
@@ -276,7 +329,11 @@ class _HubKernelConfig:
|
|||||||
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
repo_id="kernels-community/flash-attn3",
|
||||||
|
function_attr="flash_attn_func",
|
||||||
|
revision="fake-ops-return-probs",
|
||||||
|
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
|
||||||
|
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
|
||||||
),
|
),
|
||||||
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
|
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
|
||||||
repo_id="kernels-community/flash-attn3",
|
repo_id="kernels-community/flash-attn3",
|
||||||
@@ -676,7 +733,7 @@ def _wrapped_flash_attn_3(
|
|||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Hardcoded for now because pytorch does not support tuple/int type hints
|
# Hardcoded for now because pytorch does not support tuple/int type hints
|
||||||
window_size = (-1, -1)
|
window_size = (-1, -1)
|
||||||
out, lse, *_ = flash_attn_3_func(
|
result = flash_attn_3_func(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
@@ -693,7 +750,9 @@ def _wrapped_flash_attn_3(
|
|||||||
pack_gqa=pack_gqa,
|
pack_gqa=pack_gqa,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
sm_margin=sm_margin,
|
sm_margin=sm_margin,
|
||||||
|
return_attn_probs=True,
|
||||||
)
|
)
|
||||||
|
out, lse, *_ = result
|
||||||
lse = lse.permute(0, 2, 1)
|
lse = lse.permute(0, 2, 1)
|
||||||
return out, lse
|
return out, lse
|
||||||
|
|
||||||
@@ -1237,36 +1296,62 @@ def _flash_attention_3_hub_forward_op(
|
|||||||
if enable_gqa:
|
if enable_gqa:
|
||||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
|
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
|
||||||
|
|
||||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
|
||||||
out = func(
|
wrapped_forward_fn = config.wrapped_forward_fn
|
||||||
q=query,
|
if wrapped_forward_fn is None:
|
||||||
k=key,
|
raise RuntimeError(
|
||||||
v=value,
|
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
|
||||||
softmax_scale=scale,
|
"for context parallel execution."
|
||||||
|
)
|
||||||
|
|
||||||
|
if scale is None:
|
||||||
|
scale = query.shape[-1] ** (-0.5)
|
||||||
|
|
||||||
|
out, softmax_lse, *_ = wrapped_forward_fn(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
None,
|
||||||
|
None, # k_new, v_new
|
||||||
|
None, # qv
|
||||||
|
None, # out
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None, # cu_seqlens_q/k/k_new
|
||||||
|
None,
|
||||||
|
None, # seqused_q/k
|
||||||
|
None,
|
||||||
|
None, # max_seqlen_q/k
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None, # page_table, kv_batch_idx, leftpad_k
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None, # rotary_cos/sin, seqlens_rotary
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None, # q_descale, k_descale, v_descale
|
||||||
|
scale,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
qv=None,
|
window_size_left=window_size[0],
|
||||||
q_descale=None,
|
window_size_right=window_size[1],
|
||||||
k_descale=None,
|
attention_chunk=0,
|
||||||
v_descale=None,
|
|
||||||
window_size=window_size,
|
|
||||||
softcap=softcap,
|
softcap=softcap,
|
||||||
num_splits=num_splits,
|
num_splits=num_splits,
|
||||||
pack_gqa=pack_gqa,
|
pack_gqa=pack_gqa,
|
||||||
deterministic=deterministic,
|
|
||||||
sm_margin=sm_margin,
|
sm_margin=sm_margin,
|
||||||
return_attn_probs=return_lse,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
lse = None
|
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
|
||||||
if return_lse:
|
|
||||||
out, lse = out
|
|
||||||
lse = lse.permute(0, 2, 1).contiguous()
|
|
||||||
|
|
||||||
if _save_ctx:
|
if _save_ctx:
|
||||||
ctx.save_for_backward(query, key, value)
|
ctx.save_for_backward(query, key, value, out, softmax_lse)
|
||||||
ctx.scale = scale
|
ctx.scale = scale
|
||||||
ctx.is_causal = is_causal
|
ctx.is_causal = is_causal
|
||||||
ctx._hub_kernel = func
|
ctx.window_size = window_size
|
||||||
|
ctx.softcap = softcap
|
||||||
|
ctx.deterministic = deterministic
|
||||||
|
ctx.sm_margin = sm_margin
|
||||||
|
|
||||||
return (out, lse) if return_lse else out
|
return (out, lse) if return_lse else out
|
||||||
|
|
||||||
@@ -1275,54 +1360,49 @@ def _flash_attention_3_hub_backward_op(
|
|||||||
ctx: torch.autograd.function.FunctionCtx,
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
grad_out: torch.Tensor,
|
grad_out: torch.Tensor,
|
||||||
*args,
|
*args,
|
||||||
window_size: tuple[int, int] = (-1, -1),
|
**kwargs,
|
||||||
softcap: float = 0.0,
|
|
||||||
num_splits: int = 1,
|
|
||||||
pack_gqa: bool | None = None,
|
|
||||||
deterministic: bool = False,
|
|
||||||
sm_margin: int = 0,
|
|
||||||
):
|
):
|
||||||
query, key, value = ctx.saved_tensors
|
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
|
||||||
kernel_fn = ctx._hub_kernel
|
wrapped_backward_fn = config.wrapped_backward_fn
|
||||||
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
|
if wrapped_backward_fn is None:
|
||||||
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
|
raise RuntimeError(
|
||||||
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
|
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
|
||||||
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
|
"for context parallel execution."
|
||||||
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
|
|
||||||
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
|
|
||||||
with torch.enable_grad():
|
|
||||||
query_r = query.detach().requires_grad_(True)
|
|
||||||
key_r = key.detach().requires_grad_(True)
|
|
||||||
value_r = value.detach().requires_grad_(True)
|
|
||||||
|
|
||||||
out = kernel_fn(
|
|
||||||
q=query_r,
|
|
||||||
k=key_r,
|
|
||||||
v=value_r,
|
|
||||||
softmax_scale=ctx.scale,
|
|
||||||
causal=ctx.is_causal,
|
|
||||||
qv=None,
|
|
||||||
q_descale=None,
|
|
||||||
k_descale=None,
|
|
||||||
v_descale=None,
|
|
||||||
window_size=window_size,
|
|
||||||
softcap=softcap,
|
|
||||||
num_splits=num_splits,
|
|
||||||
pack_gqa=pack_gqa,
|
|
||||||
deterministic=deterministic,
|
|
||||||
sm_margin=sm_margin,
|
|
||||||
return_attn_probs=False,
|
|
||||||
)
|
)
|
||||||
if isinstance(out, tuple):
|
|
||||||
out = out[0]
|
|
||||||
|
|
||||||
grad_query, grad_key, grad_value = torch.autograd.grad(
|
query, key, value, out, softmax_lse = ctx.saved_tensors
|
||||||
out,
|
grad_query = torch.empty_like(query)
|
||||||
(query_r, key_r, value_r),
|
grad_key = torch.empty_like(key)
|
||||||
grad_out,
|
grad_value = torch.empty_like(value)
|
||||||
retain_graph=False,
|
|
||||||
allow_unused=False,
|
wrapped_backward_fn(
|
||||||
)
|
grad_out,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
out,
|
||||||
|
softmax_lse,
|
||||||
|
None,
|
||||||
|
None, # cu_seqlens_q, cu_seqlens_k
|
||||||
|
None,
|
||||||
|
None, # seqused_q, seqused_k
|
||||||
|
None,
|
||||||
|
None, # max_seqlen_q, max_seqlen_k
|
||||||
|
grad_query,
|
||||||
|
grad_key,
|
||||||
|
grad_value,
|
||||||
|
ctx.scale,
|
||||||
|
ctx.is_causal,
|
||||||
|
ctx.window_size[0],
|
||||||
|
ctx.window_size[1],
|
||||||
|
ctx.softcap,
|
||||||
|
ctx.deterministic,
|
||||||
|
ctx.sm_margin,
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||||
|
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||||
|
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||||
|
|
||||||
return grad_query, grad_key, grad_value
|
return grad_query, grad_key, grad_value
|
||||||
|
|
||||||
@@ -2623,7 +2703,7 @@ def _flash_varlen_attention_3(
|
|||||||
key_packed = torch.cat(key_valid, dim=0)
|
key_packed = torch.cat(key_valid, dim=0)
|
||||||
value_packed = torch.cat(value_valid, dim=0)
|
value_packed = torch.cat(value_valid, dim=0)
|
||||||
|
|
||||||
out, lse, *_ = flash_attn_3_varlen_func(
|
result = flash_attn_3_varlen_func(
|
||||||
q=query_packed,
|
q=query_packed,
|
||||||
k=key_packed,
|
k=key_packed,
|
||||||
v=value_packed,
|
v=value_packed,
|
||||||
@@ -2633,7 +2713,13 @@ def _flash_varlen_attention_3(
|
|||||||
max_seqlen_k=max_seqlen_k,
|
max_seqlen_k=max_seqlen_k,
|
||||||
softmax_scale=scale,
|
softmax_scale=scale,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
|
return_attn_probs=return_lse,
|
||||||
)
|
)
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
out, lse, *_ = result
|
||||||
|
else:
|
||||||
|
out = result
|
||||||
|
lse = None
|
||||||
out = out.unflatten(0, (batch_size, -1))
|
out = out.unflatten(0, (batch_size, -1))
|
||||||
|
|
||||||
return (out, lse) if return_lse else out
|
return (out, lse) if return_lse else out
|
||||||
|
|||||||
@@ -191,7 +191,12 @@ class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
|
if condition_mask is not None:
|
||||||
|
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
|
||||||
|
else:
|
||||||
|
control_hidden_states = torch.cat(
|
||||||
|
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
padding_mask_resized = transforms.functional.resize(
|
padding_mask_resized = transforms.functional.resize(
|
||||||
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||||
|
|||||||
@@ -17,9 +17,6 @@ from typing import Callable, Dict, List, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
|
||||||
import torchvision.transforms
|
|
||||||
import torchvision.transforms.functional
|
|
||||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||||
@@ -54,11 +51,13 @@ else:
|
|||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
|
def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int):
|
||||||
n_pad_frames = num_frames - video.shape[2]
|
n_pad_frames = num_frames - video.shape[2]
|
||||||
if n_pad_frames > 0:
|
if n_pad_frames > 0:
|
||||||
last_frame = video[:, :, -1:, :, :]
|
last_frame = video[:, :, -1:, :, :]
|
||||||
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
|
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
|
||||||
|
elif num_frames < video.shape[2]:
|
||||||
|
video = video[:, :, :num_frames, :, :]
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
@@ -134,8 +133,8 @@ EXAMPLE_DOC_STRING = """
|
|||||||
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
|
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
|
||||||
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
|
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
|
||||||
|
|
||||||
|
>>> # Transfer inference with controls.
|
||||||
>>> video = pipe(
|
>>> video = pipe(
|
||||||
... video=input_video[:num_frames],
|
|
||||||
... controls=controls,
|
... controls=controls,
|
||||||
... controls_conditioning_scale=1.0,
|
... controls_conditioning_scale=1.0,
|
||||||
... prompt=prompt,
|
... prompt=prompt,
|
||||||
@@ -149,7 +148,7 @@ EXAMPLE_DOC_STRING = """
|
|||||||
|
|
||||||
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for Cosmos Transfer2.5 base model.
|
Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference.
|
||||||
|
|
||||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||||
@@ -166,12 +165,14 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||||
vae ([`AutoencoderKLWan`]):
|
vae ([`AutoencoderKLWan`]):
|
||||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||||
|
controlnet ([`CosmosControlNetModel`]):
|
||||||
|
ControlNet used to condition generation on control inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
|
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
|
||||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||||
_optional_components = ["safety_checker", "controlnet"]
|
_optional_components = ["safety_checker"]
|
||||||
_exclude_from_cpu_offload = ["safety_checker"]
|
_exclude_from_cpu_offload = ["safety_checker"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -181,8 +182,8 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
transformer: CosmosTransformer3DModel,
|
transformer: CosmosTransformer3DModel,
|
||||||
vae: AutoencoderKLWan,
|
vae: AutoencoderKLWan,
|
||||||
scheduler: UniPCMultistepScheduler,
|
scheduler: UniPCMultistepScheduler,
|
||||||
controlnet: Optional[CosmosControlNetModel],
|
controlnet: CosmosControlNetModel,
|
||||||
safety_checker: CosmosSafetyChecker = None,
|
safety_checker: Optional[CosmosSafetyChecker] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -384,10 +385,11 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
num_frames_in: int = 93,
|
num_frames_in: int = 93,
|
||||||
num_frames_out: int = 93,
|
num_frames_out: int = 93,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
dtype: torch.dtype | None = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
device: torch.device | None = None,
|
device: Optional[torch.device] = None,
|
||||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
latents: torch.Tensor | None = None,
|
latents: Optional[torch.Tensor] = None,
|
||||||
|
num_cond_latent_frames: int = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if isinstance(generator, list) and len(generator) != batch_size:
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -402,10 +404,14 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
W = width // self.vae_scale_factor_spatial
|
W = width // self.vae_scale_factor_spatial
|
||||||
shape = (B, C, T, H, W)
|
shape = (B, C, T, H, W)
|
||||||
|
|
||||||
if num_frames_in == 0:
|
if latents is not None:
|
||||||
if latents is None:
|
if latents.shape[1:] != shape[1:]:
|
||||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.")
|
||||||
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if num_frames_in == 0:
|
||||||
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
|
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
|
||||||
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
|
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
|
||||||
|
|
||||||
@@ -435,16 +441,12 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||||
cond_latents = (cond_latents - latents_mean) / latents_std
|
cond_latents = (cond_latents - latents_mean) / latents_std
|
||||||
|
|
||||||
if latents is None:
|
|
||||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
latents = latents.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
padding_shape = (B, 1, T, H, W)
|
padding_shape = (B, 1, T, H, W)
|
||||||
ones_padding = latents.new_ones(padding_shape)
|
ones_padding = latents.new_ones(padding_shape)
|
||||||
zeros_padding = latents.new_zeros(padding_shape)
|
zeros_padding = latents.new_zeros(padding_shape)
|
||||||
|
|
||||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1)
|
||||||
|
cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0
|
||||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -454,34 +456,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
cond_indicator,
|
cond_indicator,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _encode_controls(
|
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||||
self,
|
|
||||||
controls: Optional[torch.Tensor],
|
|
||||||
height: int,
|
|
||||||
width: int,
|
|
||||||
num_frames: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
generator: torch.Generator | list[torch.Generator] | None,
|
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
if controls is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
control_video = self.video_processor.preprocess_video(controls, height, width)
|
|
||||||
control_video = _maybe_pad_video(control_video, num_frames)
|
|
||||||
|
|
||||||
control_video = control_video.to(device=device, dtype=self.vae.dtype)
|
|
||||||
control_latents = [
|
|
||||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
|
|
||||||
]
|
|
||||||
control_latents = torch.cat(control_latents, dim=0).to(dtype)
|
|
||||||
|
|
||||||
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
|
|
||||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
|
||||||
control_latents = (control_latents - latents_mean) / latents_std
|
|
||||||
return control_latents
|
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
|
||||||
def check_inputs(
|
def check_inputs(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
@@ -489,9 +464,25 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
width,
|
width,
|
||||||
prompt_embeds=None,
|
prompt_embeds=None,
|
||||||
callback_on_step_end_tensor_inputs=None,
|
callback_on_step_end_tensor_inputs=None,
|
||||||
|
num_ar_conditional_frames=None,
|
||||||
|
num_ar_latent_conditional_frames=None,
|
||||||
|
num_frames_per_chunk=None,
|
||||||
|
num_frames=None,
|
||||||
|
conditional_frame_timestep=0.1,
|
||||||
):
|
):
|
||||||
if height % 16 != 0 or width % 16 != 0:
|
if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0:
|
||||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
raise ValueError(
|
||||||
|
f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_frames is not None and num_frames <= 0:
|
||||||
|
raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.")
|
||||||
|
|
||||||
|
if conditional_frame_timestep < 0 or conditional_frame_timestep > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"`conditional_frame_timestep` has to be a float in the [0, 1] interval but is "
|
||||||
|
f"{conditional_frame_timestep}."
|
||||||
|
)
|
||||||
|
|
||||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||||
@@ -512,6 +503,46 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
|
||||||
|
if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both."
|
||||||
|
)
|
||||||
|
if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None:
|
||||||
|
raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.")
|
||||||
|
if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0:
|
||||||
|
raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.")
|
||||||
|
if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0:
|
||||||
|
raise ValueError("`num_ar_conditional_frames` must be >= 0.")
|
||||||
|
|
||||||
|
if num_ar_latent_conditional_frames is not None:
|
||||||
|
num_ar_conditional_frames = max(
|
||||||
|
0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
min_chunk_len = self.vae_scale_factor_temporal + 1
|
||||||
|
if num_frames_per_chunk < min_chunk_len:
|
||||||
|
logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len")
|
||||||
|
num_frames_per_chunk = min_chunk_len
|
||||||
|
|
||||||
|
max_frames_by_rope = None
|
||||||
|
if getattr(self.transformer.config, "max_size", None) is not None:
|
||||||
|
max_frames_by_rope = max(
|
||||||
|
size // patch
|
||||||
|
for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size)
|
||||||
|
)
|
||||||
|
if num_frames_per_chunk > max_frames_by_rope:
|
||||||
|
raise ValueError(
|
||||||
|
f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). "
|
||||||
|
"Please reduce `num_frames_per_chunk`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_ar_conditional_frames >= num_frames_per_chunk:
|
||||||
|
raise ValueError(
|
||||||
|
f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation."
|
||||||
|
)
|
||||||
|
|
||||||
|
return num_frames_per_chunk
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guidance_scale(self):
|
def guidance_scale(self):
|
||||||
return self._guidance_scale
|
return self._guidance_scale
|
||||||
@@ -536,23 +567,22 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
image: PipelineImageInput | None = None,
|
controls: PipelineImageInput | List[PipelineImageInput],
|
||||||
video: List[PipelineImageInput] | None = None,
|
controls_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||||
prompt: Union[str, List[str]] | None = None,
|
prompt: Union[str, List[str]] | None = None,
|
||||||
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
|
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
|
||||||
height: int = 704,
|
height: int = 704,
|
||||||
width: int | None = None,
|
width: Optional[int] = None,
|
||||||
num_frames: int = 93,
|
num_frames: Optional[int] = None,
|
||||||
|
num_frames_per_chunk: int = 93,
|
||||||
num_inference_steps: int = 36,
|
num_inference_steps: int = 36,
|
||||||
guidance_scale: float = 3.0,
|
guidance_scale: float = 3.0,
|
||||||
num_videos_per_prompt: Optional[int] = 1,
|
num_videos_per_prompt: int = 1,
|
||||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
latents: torch.Tensor | None = None,
|
latents: Optional[torch.Tensor] = None,
|
||||||
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
controls_conditioning_scale: float | list[float] = 1.0,
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
prompt_embeds: torch.Tensor | None = None,
|
output_type: Optional[str] = "pil",
|
||||||
negative_prompt_embeds: torch.Tensor | None = None,
|
|
||||||
output_type: str = "pil",
|
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
callback_on_step_end: Optional[
|
callback_on_step_end: Optional[
|
||||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||||
@@ -560,24 +590,26 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
max_sequence_length: int = 512,
|
max_sequence_length: int = 512,
|
||||||
conditional_frame_timestep: float = 0.1,
|
conditional_frame_timestep: float = 0.1,
|
||||||
|
num_ar_conditional_frames: Optional[int] = 1,
|
||||||
|
num_ar_latent_conditional_frames: Optional[int] = None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
The call function to the pipeline for generation. Supports three modes:
|
`controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps
|
||||||
|
are pre-computed.
|
||||||
|
|
||||||
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
|
Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None
|
||||||
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
|
(default) then the number of output frames will match the input `controls`.
|
||||||
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
|
|
||||||
|
|
||||||
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
|
Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames are used per
|
||||||
above in "*2Image mode").
|
denoising loop. In addition, when auto-regressive inference is performed, the previous
|
||||||
|
`num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising
|
||||||
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
|
inference loops.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
|
controls (`PipelineImageInput`, `List[PipelineImageInput]`):
|
||||||
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
|
Control image or video input used by the ControlNet.
|
||||||
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
|
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
|
||||||
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
|
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
|
||||||
prompt (`str` or `List[str]`, *optional*):
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
|
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
|
||||||
height (`int`, defaults to `704`):
|
height (`int`, defaults to `704`):
|
||||||
@@ -585,9 +617,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
width (`int`, *optional*):
|
width (`int`, *optional*):
|
||||||
The width in pixels of the generated image. If not provided, this will be determined based on the
|
The width in pixels of the generated image. If not provided, this will be determined based on the
|
||||||
aspect ratio of the input and the provided height.
|
aspect ratio of the input and the provided height.
|
||||||
num_frames (`int`, defaults to `93`):
|
num_frames (`int`, *optional*):
|
||||||
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
|
Number of output frames. Defaults to `None` to output the same number of frames as the input
|
||||||
num_inference_steps (`int`, defaults to `35`):
|
`controls`.
|
||||||
|
num_inference_steps (`int`, defaults to `36`):
|
||||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
expense of slower inference.
|
expense of slower inference.
|
||||||
guidance_scale (`float`, defaults to `3.0`):
|
guidance_scale (`float`, defaults to `3.0`):
|
||||||
@@ -601,13 +634,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||||
generation deterministic.
|
generation deterministic.
|
||||||
latents (`torch.Tensor`, *optional*):
|
latents (`torch.Tensor`, *optional*):
|
||||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
|
||||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
tweak the same generation with different prompts. If not provided, a latents tensor is generated by
|
||||||
tensor is generated by sampling using the supplied random `generator`.
|
sampling using the supplied random `generator`.
|
||||||
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
|
|
||||||
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
|
|
||||||
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
|
|
||||||
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
|
|
||||||
prompt_embeds (`torch.Tensor`, *optional*):
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
provided, text embeddings will be generated from `prompt` input argument.
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
@@ -630,7 +659,18 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
max_sequence_length (`int`, defaults to `512`):
|
max_sequence_length (`int`, defaults to `512`):
|
||||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||||
the prompt is shorter than this length, it will be padded.
|
the prompt is shorter than this length, it will be padded.
|
||||||
|
num_ar_conditional_frames (`int`, *optional*, defaults to `1`):
|
||||||
|
Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the
|
||||||
|
second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`.
|
||||||
|
|
||||||
|
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
|
||||||
|
controls is > num_frames_per_chunk
|
||||||
|
num_ar_latent_conditional_frames (`int`, *optional*):
|
||||||
|
Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e.
|
||||||
|
for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`.
|
||||||
|
|
||||||
|
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
|
||||||
|
controls is > num_frames_per_chunk
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -650,21 +690,40 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||||
|
|
||||||
if width is None:
|
if width is None:
|
||||||
frame = image or video[0] if image or video else None
|
frame = controls[0] if isinstance(controls, list) else controls
|
||||||
if frame is None and controls is not None:
|
if isinstance(frame, list):
|
||||||
frame = controls[0] if isinstance(controls, list) else controls
|
frame = frame[0]
|
||||||
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
|
if isinstance(frame, (torch.Tensor, np.ndarray)):
|
||||||
frame = controls[0]
|
if frame.ndim == 5:
|
||||||
|
frame = frame[0, 0]
|
||||||
|
elif frame.ndim == 4:
|
||||||
|
frame = frame[0]
|
||||||
|
|
||||||
if frame is None:
|
if isinstance(frame, PIL.Image.Image):
|
||||||
width = int((height + 16) * (1280 / 720))
|
|
||||||
elif isinstance(frame, PIL.Image.Image):
|
|
||||||
width = int((height + 16) * (frame.width / frame.height))
|
width = int((height + 16) * (frame.width / frame.height))
|
||||||
else:
|
else:
|
||||||
|
if frame.ndim != 3:
|
||||||
|
raise ValueError("`controls` must contain 3D frames in CHW format.")
|
||||||
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
|
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
|
||||||
|
|
||||||
# Check inputs. Raise error if not correct
|
num_frames_per_chunk = self.check_inputs(
|
||||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
prompt,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds,
|
||||||
|
callback_on_step_end_tensor_inputs,
|
||||||
|
num_ar_conditional_frames,
|
||||||
|
num_ar_latent_conditional_frames,
|
||||||
|
num_frames_per_chunk,
|
||||||
|
num_frames,
|
||||||
|
conditional_frame_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_ar_latent_conditional_frames is not None:
|
||||||
|
num_cond_latent_frames = num_ar_latent_conditional_frames
|
||||||
|
num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
|
||||||
|
else:
|
||||||
|
num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
|
||||||
|
|
||||||
self._guidance_scale = guidance_scale
|
self._guidance_scale = guidance_scale
|
||||||
self._current_timestep = None
|
self._current_timestep = None
|
||||||
@@ -709,102 +768,137 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
vae_dtype = self.vae.dtype
|
vae_dtype = self.vae.dtype
|
||||||
transformer_dtype = self.transformer.dtype
|
transformer_dtype = self.transformer.dtype
|
||||||
|
|
||||||
img_context = torch.zeros(
|
if getattr(self.transformer.config, "img_context_dim_in", None):
|
||||||
batch_size,
|
img_context = torch.zeros(
|
||||||
self.transformer.config.img_context_num_tokens,
|
batch_size,
|
||||||
self.transformer.config.img_context_dim_in,
|
self.transformer.config.img_context_num_tokens,
|
||||||
device=prompt_embeds.device,
|
self.transformer.config.img_context_dim_in,
|
||||||
dtype=transformer_dtype,
|
device=prompt_embeds.device,
|
||||||
)
|
|
||||||
encoder_hidden_states = (prompt_embeds, img_context)
|
|
||||||
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
|
|
||||||
|
|
||||||
num_frames_in = None
|
|
||||||
if image is not None:
|
|
||||||
if batch_size != 1:
|
|
||||||
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
|
|
||||||
|
|
||||||
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
|
|
||||||
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
|
|
||||||
video = video.unsqueeze(0)
|
|
||||||
num_frames_in = 1
|
|
||||||
elif video is None:
|
|
||||||
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
|
|
||||||
num_frames_in = 0
|
|
||||||
else:
|
|
||||||
num_frames_in = len(video)
|
|
||||||
|
|
||||||
if batch_size != 1:
|
|
||||||
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
|
|
||||||
|
|
||||||
assert video is not None
|
|
||||||
video = self.video_processor.preprocess_video(video, height, width)
|
|
||||||
|
|
||||||
# pad with last frame (for video2world)
|
|
||||||
num_frames_out = num_frames
|
|
||||||
video = _maybe_pad_video(video, num_frames_out)
|
|
||||||
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
|
|
||||||
|
|
||||||
video = video.to(device=device, dtype=vae_dtype)
|
|
||||||
|
|
||||||
num_channels_latents = self.transformer.config.in_channels - 1
|
|
||||||
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
|
||||||
video=video,
|
|
||||||
batch_size=batch_size * num_videos_per_prompt,
|
|
||||||
num_channels_latents=num_channels_latents,
|
|
||||||
height=height,
|
|
||||||
width=width,
|
|
||||||
num_frames_in=num_frames_in,
|
|
||||||
num_frames_out=num_frames,
|
|
||||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device,
|
|
||||||
generator=generator,
|
|
||||||
latents=latents,
|
|
||||||
)
|
|
||||||
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
|
||||||
cond_mask = cond_mask.to(transformer_dtype)
|
|
||||||
|
|
||||||
controls_latents = None
|
|
||||||
if controls is not None:
|
|
||||||
controls_latents = self._encode_controls(
|
|
||||||
controls,
|
|
||||||
height=height,
|
|
||||||
width=width,
|
|
||||||
num_frames=num_frames,
|
|
||||||
dtype=transformer_dtype,
|
dtype=transformer_dtype,
|
||||||
device=device,
|
|
||||||
generator=generator,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
if num_videos_per_prompt > 1:
|
||||||
|
img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
|
||||||
|
|
||||||
# Denoising loop
|
encoder_hidden_states = (prompt_embeds, img_context)
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
|
||||||
timesteps = self.scheduler.timesteps
|
else:
|
||||||
self._num_timesteps = len(timesteps)
|
encoder_hidden_states = prompt_embeds
|
||||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
neg_encoder_hidden_states = negative_prompt_embeds
|
||||||
|
|
||||||
gt_velocity = (latents - cond_latent) * cond_mask
|
control_video = self.video_processor.preprocess_video(controls, height, width)
|
||||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
if control_video.shape[0] != batch_size:
|
||||||
for i, t in enumerate(timesteps):
|
if control_video.shape[0] == 1:
|
||||||
if self.interrupt:
|
control_video = control_video.repeat(batch_size, 1, 1, 1, 1)
|
||||||
continue
|
else:
|
||||||
|
raise ValueError(
|
||||||
self._current_timestep = t.cpu().item()
|
f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}."
|
||||||
|
|
||||||
# NOTE: assumes sigma(t) \in [0, 1]
|
|
||||||
sigma_t = (
|
|
||||||
torch.tensor(self.scheduler.sigmas[i].item())
|
|
||||||
.unsqueeze(0)
|
|
||||||
.to(device=device, dtype=transformer_dtype)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
num_frames_out = control_video.shape[2]
|
||||||
in_latents = in_latents.to(transformer_dtype)
|
if num_frames is not None:
|
||||||
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
num_frames_out = min(num_frames_out, num_frames)
|
||||||
control_blocks = None
|
|
||||||
if controls_latents is not None and self.controlnet is not None:
|
control_video = _maybe_pad_or_trim_video(control_video, num_frames_out)
|
||||||
|
|
||||||
|
# chunk information
|
||||||
|
num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1
|
||||||
|
chunk_stride = num_frames_per_chunk - num_ar_conditional_frames
|
||||||
|
chunk_idxs = [
|
||||||
|
(start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
|
||||||
|
for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride)
|
||||||
|
]
|
||||||
|
|
||||||
|
video_chunks = []
|
||||||
|
latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device)
|
||||||
|
latents_std = self.latents_std.to(dtype=vae_dtype, device=device)
|
||||||
|
|
||||||
|
def decode_latents(latents):
|
||||||
|
latents = latents * latents_std + latents_mean
|
||||||
|
video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
|
||||||
|
return video
|
||||||
|
|
||||||
|
latents_arg = latents
|
||||||
|
initial_num_cond_latent_frames = 0
|
||||||
|
latent_chunks = []
|
||||||
|
num_chunks = len(chunk_idxs)
|
||||||
|
total_steps = num_inference_steps * num_chunks
|
||||||
|
with self.progress_bar(total=total_steps) as progress_bar:
|
||||||
|
for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs):
|
||||||
|
if chunk_idx == 0:
|
||||||
|
prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype)
|
||||||
|
prev_output = self.video_processor.preprocess_video(prev_output, height, width)
|
||||||
|
else:
|
||||||
|
prev_output = video_chunks[-1].clone()
|
||||||
|
if num_ar_conditional_frames > 0:
|
||||||
|
prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:]
|
||||||
|
prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space
|
||||||
|
else:
|
||||||
|
prev_output.fill_(-1)
|
||||||
|
|
||||||
|
chunk_video = prev_output.to(device=device, dtype=vae_dtype)
|
||||||
|
chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk)
|
||||||
|
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
||||||
|
video=chunk_video,
|
||||||
|
batch_size=batch_size * num_videos_per_prompt,
|
||||||
|
num_channels_latents=self.transformer.config.in_channels - 1,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames_in=chunk_video.shape[2],
|
||||||
|
num_frames_out=num_frames_per_chunk,
|
||||||
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
generator=generator,
|
||||||
|
num_cond_latent_frames=initial_num_cond_latent_frames
|
||||||
|
if chunk_idx == 0
|
||||||
|
else num_cond_latent_frames,
|
||||||
|
latents=latents_arg,
|
||||||
|
)
|
||||||
|
cond_mask = cond_mask.to(transformer_dtype)
|
||||||
|
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
||||||
|
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||||
|
|
||||||
|
chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to(
|
||||||
|
device=device, dtype=self.vae.dtype
|
||||||
|
)
|
||||||
|
chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk)
|
||||||
|
if isinstance(generator, list):
|
||||||
|
controls_latents = [
|
||||||
|
retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i])
|
||||||
|
for i in range(chunk_control_video.shape[0])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
controls_latents = [
|
||||||
|
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator)
|
||||||
|
for vid in chunk_control_video
|
||||||
|
]
|
||||||
|
controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype)
|
||||||
|
|
||||||
|
controls_latents = (controls_latents - latents_mean) / latents_std
|
||||||
|
|
||||||
|
# Denoising loop
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
timesteps = self.scheduler.timesteps
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
|
||||||
|
gt_velocity = (latents - cond_latent) * cond_mask
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._current_timestep = t.cpu().item()
|
||||||
|
|
||||||
|
# NOTE: assumes sigma(t) \in [0, 1]
|
||||||
|
sigma_t = (
|
||||||
|
torch.tensor(self.scheduler.sigmas[i].item())
|
||||||
|
.unsqueeze(0)
|
||||||
|
.to(device=device, dtype=transformer_dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
||||||
|
in_latents = in_latents.to(transformer_dtype)
|
||||||
|
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
||||||
control_output = self.controlnet(
|
control_output = self.controlnet(
|
||||||
controls_latents=controls_latents,
|
controls_latents=controls_latents,
|
||||||
latents=in_latents,
|
latents=in_latents,
|
||||||
@@ -817,20 +911,18 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
control_blocks = control_output[0]
|
control_blocks = control_output[0]
|
||||||
|
|
||||||
noise_pred = self.transformer(
|
noise_pred = self.transformer(
|
||||||
hidden_states=in_latents,
|
hidden_states=in_latents,
|
||||||
timestep=in_timestep,
|
timestep=in_timestep,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
block_controlnet_hidden_states=control_blocks,
|
block_controlnet_hidden_states=control_blocks,
|
||||||
condition_mask=cond_mask,
|
condition_mask=cond_mask,
|
||||||
padding_mask=padding_mask,
|
padding_mask=padding_mask,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
||||||
|
|
||||||
if self.do_classifier_free_guidance:
|
if self.do_classifier_free_guidance:
|
||||||
control_blocks = None
|
|
||||||
if controls_latents is not None and self.controlnet is not None:
|
|
||||||
control_output = self.controlnet(
|
control_output = self.controlnet(
|
||||||
controls_latents=controls_latents,
|
controls_latents=controls_latents,
|
||||||
latents=in_latents,
|
latents=in_latents,
|
||||||
@@ -843,46 +935,50 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
control_blocks = control_output[0]
|
control_blocks = control_output[0]
|
||||||
|
|
||||||
noise_pred_neg = self.transformer(
|
noise_pred_neg = self.transformer(
|
||||||
hidden_states=in_latents,
|
hidden_states=in_latents,
|
||||||
timestep=in_timestep,
|
timestep=in_timestep,
|
||||||
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
|
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
|
||||||
block_controlnet_hidden_states=control_blocks,
|
block_controlnet_hidden_states=control_blocks,
|
||||||
condition_mask=cond_mask,
|
condition_mask=cond_mask,
|
||||||
padding_mask=padding_mask,
|
padding_mask=padding_mask,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
||||||
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
||||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
||||||
|
|
||||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||||
|
|
||||||
if callback_on_step_end is not None:
|
# call the callback, if provided
|
||||||
callback_kwargs = {}
|
if callback_on_step_end is not None:
|
||||||
for k in callback_on_step_end_tensor_inputs:
|
callback_kwargs = {}
|
||||||
callback_kwargs[k] = locals()[k]
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||||
|
|
||||||
latents = callback_outputs.pop("latents", latents)
|
latents = callback_outputs.pop("latents", latents)
|
||||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||||
|
|
||||||
# call the callback, if provided
|
if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0):
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
progress_bar.update()
|
||||||
progress_bar.update()
|
|
||||||
|
|
||||||
if XLA_AVAILABLE:
|
if XLA_AVAILABLE:
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
|
|
||||||
|
video_chunks.append(decode_latents(latents).detach().cpu())
|
||||||
|
latent_chunks.append(latents.detach().cpu())
|
||||||
|
|
||||||
self._current_timestep = None
|
self._current_timestep = None
|
||||||
|
|
||||||
if not output_type == "latent":
|
if not output_type == "latent":
|
||||||
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
|
video_chunks = [
|
||||||
latents_std = self.latents_std.to(latents.device, latents.dtype)
|
chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk
|
||||||
latents = latents * latents_std + latents_mean
|
for chunk_idx, chunk in enumerate(video_chunks)
|
||||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
]
|
||||||
video = self._match_num_frames(video, num_frames)
|
video = torch.cat(video_chunks, dim=2)
|
||||||
|
video = video[:, :, :num_frames_out, ...]
|
||||||
|
|
||||||
assert self.safety_checker is not None
|
assert self.safety_checker is not None
|
||||||
self.safety_checker.to(device)
|
self.safety_checker.to(device)
|
||||||
@@ -899,7 +995,13 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||||
else:
|
else:
|
||||||
video = latents
|
latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
|
||||||
|
latent_chunks = [
|
||||||
|
chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
|
||||||
|
for chunk_idx, chunk in enumerate(latent_chunks)
|
||||||
|
]
|
||||||
|
video = torch.cat(latent_chunks, dim=2)
|
||||||
|
video = video[:, :, :latent_T, ...]
|
||||||
|
|
||||||
# Offload all models
|
# Offload all models
|
||||||
self.maybe_free_model_hooks()
|
self.maybe_free_model_hooks()
|
||||||
@@ -908,19 +1010,3 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
|||||||
return (video,)
|
return (video,)
|
||||||
|
|
||||||
return CosmosPipelineOutput(frames=video)
|
return CosmosPipelineOutput(frames=video)
|
||||||
|
|
||||||
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
|
|
||||||
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
|
|
||||||
return video
|
|
||||||
|
|
||||||
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
|
|
||||||
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
|
|
||||||
|
|
||||||
current_frames = video.shape[2]
|
|
||||||
if current_frames < target_num_frames:
|
|
||||||
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
|
|
||||||
video = torch.cat([video, pad], dim=2)
|
|
||||||
elif current_frames > target_num_frames:
|
|
||||||
video = video[:, :, :target_num_frames]
|
|
||||||
|
|
||||||
return video
|
|
||||||
|
|||||||
@@ -699,9 +699,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
|||||||
mask_shape = (batch_size, 1, num_frames, height, width)
|
mask_shape = (batch_size, 1, num_frames, height, width)
|
||||||
|
|
||||||
if latents is not None:
|
if latents is not None:
|
||||||
conditioning_mask = latents.new_zeros(mask_shape)
|
|
||||||
conditioning_mask[:, :, 0] = 1.0
|
|
||||||
if latents.ndim == 5:
|
if latents.ndim == 5:
|
||||||
|
# conditioning_mask needs to be the same shape as latents in two-stage generation.
|
||||||
|
batch_size, _, num_frames, height, width = latents.shape
|
||||||
|
mask_shape = (batch_size, 1, num_frames, height, width)
|
||||||
|
conditioning_mask = latents.new_zeros(mask_shape)
|
||||||
|
conditioning_mask[:, :, 0] = 1.0
|
||||||
|
|
||||||
latents = self._normalize_latents(
|
latents = self._normalize_latents(
|
||||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||||
)
|
)
|
||||||
@@ -710,6 +714,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
|||||||
latents = self._pack_latents(
|
latents = self._pack_latents(
|
||||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
conditioning_mask = latents.new_zeros(mask_shape)
|
||||||
|
conditioning_mask[:, :, 0] = 1.0
|
||||||
conditioning_mask = self._pack_latents(
|
conditioning_mask = self._pack_latents(
|
||||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||||
).squeeze(-1)
|
).squeeze(-1)
|
||||||
|
|||||||
@@ -276,7 +276,7 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def do_classifier_free_guidance(self):
|
def do_classifier_free_guidance(self):
|
||||||
return self._guidance_scale > 1
|
return self._guidance_scale > 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def joint_attention_kwargs(self):
|
def joint_attention_kwargs(self):
|
||||||
|
|||||||
@@ -299,7 +299,10 @@ def get_cached_module_file(
|
|||||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
|
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
|
|
||||||
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
if subfolder is not None:
|
||||||
|
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
|
||||||
|
else:
|
||||||
|
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
||||||
|
|
||||||
if os.path.isfile(module_file_or_url):
|
if os.path.isfile(module_file_or_url):
|
||||||
resolved_module_file = module_file_or_url
|
resolved_module_file = module_file_or_url
|
||||||
@@ -384,7 +387,11 @@ def get_cached_module_file(
|
|||||||
if not os.path.exists(submodule_path / module_folder):
|
if not os.path.exists(submodule_path / module_folder):
|
||||||
os.makedirs(submodule_path / module_folder)
|
os.makedirs(submodule_path / module_folder)
|
||||||
module_needed = f"{module_needed}.py"
|
module_needed = f"{module_needed}.py"
|
||||||
shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
if subfolder is not None:
|
||||||
|
source_path = os.path.join(pretrained_model_name_or_path, subfolder, module_needed)
|
||||||
|
else:
|
||||||
|
source_path = os.path.join(pretrained_model_name_or_path, module_needed)
|
||||||
|
shutil.copyfile(source_path, submodule_path / module_needed)
|
||||||
else:
|
else:
|
||||||
# Get the commit hash
|
# Get the commit hash
|
||||||
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
||||||
|
|||||||
@@ -131,6 +131,26 @@ class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertIsInstance(output[0], list)
|
self.assertIsInstance(output[0], list)
|
||||||
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
|
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
|
||||||
|
|
||||||
|
def test_condition_mask_changes_output(self):
|
||||||
|
"""Test that condition mask affects control outputs."""
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs_no_mask = dict(inputs_dict)
|
||||||
|
inputs_no_mask["condition_mask"] = torch.zeros_like(inputs_dict["condition_mask"])
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output_no_mask = model(**inputs_no_mask)
|
||||||
|
output_with_mask = model(**inputs_dict)
|
||||||
|
|
||||||
|
self.assertEqual(len(output_no_mask.control_block_samples), len(output_with_mask.control_block_samples))
|
||||||
|
for no_mask_tensor, with_mask_tensor in zip(
|
||||||
|
output_no_mask.control_block_samples, output_with_mask.control_block_samples
|
||||||
|
):
|
||||||
|
self.assertFalse(torch.allclose(no_mask_tensor, with_mask_tensor))
|
||||||
|
|
||||||
def test_conditioning_scale_single(self):
|
def test_conditioning_scale_single(self):
|
||||||
"""Test that a single conditioning scale is broadcast to all blocks."""
|
"""Test that a single conditioning scale is broadcast to all blocks."""
|
||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import CLIPTextModel, LongformerModel
|
from transformers import CLIPTextModel, LongformerModel
|
||||||
|
|
||||||
from diffusers.models import AutoModel, UNet2DConditionModel
|
from diffusers.models import AutoModel, UNet2DConditionModel
|
||||||
@@ -35,6 +39,45 @@ class TestAutoModel(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
assert isinstance(model, CLIPTextModel)
|
assert isinstance(model, CLIPTextModel)
|
||||||
|
|
||||||
|
def test_load_dynamic_module_from_local_path_with_subfolder(self):
|
||||||
|
CUSTOM_MODEL_CODE = (
|
||||||
|
"import torch\n"
|
||||||
|
"from diffusers import ModelMixin, ConfigMixin\n"
|
||||||
|
"from diffusers.configuration_utils import register_to_config\n"
|
||||||
|
"\n"
|
||||||
|
"class CustomModel(ModelMixin, ConfigMixin):\n"
|
||||||
|
" @register_to_config\n"
|
||||||
|
" def __init__(self, hidden_size=8):\n"
|
||||||
|
" super().__init__()\n"
|
||||||
|
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
|
||||||
|
"\n"
|
||||||
|
" def forward(self, x):\n"
|
||||||
|
" return self.linear(x)\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
subfolder = "custom_model"
|
||||||
|
model_dir = os.path.join(tmpdir, subfolder)
|
||||||
|
os.makedirs(model_dir)
|
||||||
|
|
||||||
|
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
|
||||||
|
f.write(CUSTOM_MODEL_CODE)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"_class_name": "CustomModel",
|
||||||
|
"_diffusers_version": "0.0.0",
|
||||||
|
"auto_map": {"AutoModel": "modeling.CustomModel"},
|
||||||
|
"hidden_size": 8,
|
||||||
|
}
|
||||||
|
with open(os.path.join(model_dir, "config.json"), "w") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
|
||||||
|
torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
|
||||||
|
assert model.__class__.__name__ == "CustomModel"
|
||||||
|
assert model.config["hidden_size"] == 8
|
||||||
|
|
||||||
|
|
||||||
class TestAutoModelFromConfig(unittest.TestCase):
|
class TestAutoModelFromConfig(unittest.TestCase):
|
||||||
@patch(
|
@patch(
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline):
|
|||||||
class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
pipeline_class = Cosmos2_5_TransferWrapper
|
pipeline_class = Cosmos2_5_TransferWrapper
|
||||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"controls"})
|
||||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
required_optional_params = frozenset(
|
required_optional_params = frozenset(
|
||||||
@@ -176,15 +176,19 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
|||||||
else:
|
else:
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
|
||||||
|
controls_generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
"prompt": "dance monkey",
|
"prompt": "dance monkey",
|
||||||
"negative_prompt": "bad quality",
|
"negative_prompt": "bad quality",
|
||||||
|
"controls": [torch.randn(3, 32, 32, generator=controls_generator) for _ in range(5)],
|
||||||
"generator": generator,
|
"generator": generator,
|
||||||
"num_inference_steps": 2,
|
"num_inference_steps": 2,
|
||||||
"guidance_scale": 3.0,
|
"guidance_scale": 3.0,
|
||||||
"height": 32,
|
"height": 32,
|
||||||
"width": 32,
|
"width": 32,
|
||||||
"num_frames": 3,
|
"num_frames": 3,
|
||||||
|
"num_frames_per_chunk": 16,
|
||||||
"max_sequence_length": 16,
|
"max_sequence_length": 16,
|
||||||
"output_type": "pt",
|
"output_type": "pt",
|
||||||
}
|
}
|
||||||
@@ -212,6 +216,56 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
|||||||
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
|
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
|
||||||
self.assertTrue(torch.isfinite(generated_video).all())
|
self.assertTrue(torch.isfinite(generated_video).all())
|
||||||
|
|
||||||
|
def test_inference_autoregressive_multi_chunk(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
inputs["num_frames"] = 5
|
||||||
|
inputs["num_frames_per_chunk"] = 3
|
||||||
|
inputs["num_ar_conditional_frames"] = 1
|
||||||
|
|
||||||
|
video = pipe(**inputs).frames
|
||||||
|
generated_video = video[0]
|
||||||
|
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
|
||||||
|
self.assertTrue(torch.isfinite(generated_video).all())
|
||||||
|
|
||||||
|
def test_inference_autoregressive_multi_chunk_no_condition_frames(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
inputs["num_frames"] = 5
|
||||||
|
inputs["num_frames_per_chunk"] = 3
|
||||||
|
inputs["num_ar_conditional_frames"] = 0
|
||||||
|
|
||||||
|
video = pipe(**inputs).frames
|
||||||
|
generated_video = video[0]
|
||||||
|
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
|
||||||
|
self.assertTrue(torch.isfinite(generated_video).all())
|
||||||
|
|
||||||
|
def test_num_frames_per_chunk_above_rope_raises(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
inputs["num_frames_per_chunk"] = 17
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "too large for RoPE setting"):
|
||||||
|
pipe(**inputs)
|
||||||
|
|
||||||
def test_inference_with_controls(self):
|
def test_inference_with_controls(self):
|
||||||
"""Test inference with control inputs (ControlNet)."""
|
"""Test inference with control inputs (ControlNet)."""
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
@@ -222,13 +276,13 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
|||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
inputs = self.get_dummy_inputs(device)
|
inputs = self.get_dummy_inputs(device)
|
||||||
# Add control video input - should be a video tensor
|
inputs["controls"] = [torch.randn(3, 32, 32) for _ in range(5)] # list of 5 frames (C, H, W)
|
||||||
inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width
|
|
||||||
inputs["controls_conditioning_scale"] = 1.0
|
inputs["controls_conditioning_scale"] = 1.0
|
||||||
|
inputs["num_frames"] = None
|
||||||
|
|
||||||
video = pipe(**inputs).frames
|
video = pipe(**inputs).frames
|
||||||
generated_video = video[0]
|
generated_video = video[0]
|
||||||
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
|
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
|
||||||
self.assertTrue(torch.isfinite(generated_video).all())
|
self.assertTrue(torch.isfinite(generated_video).all())
|
||||||
|
|
||||||
def test_callback_inputs(self):
|
def test_callback_inputs(self):
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ from diffusers import (
|
|||||||
LTX2ImageToVideoPipeline,
|
LTX2ImageToVideoPipeline,
|
||||||
LTX2VideoTransformer3DModel,
|
LTX2VideoTransformer3DModel,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors
|
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors
|
||||||
|
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
|
||||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||||
|
|
||||||
from ...testing_utils import enable_full_determinism
|
from ...testing_utils import enable_full_determinism
|
||||||
@@ -174,6 +175,15 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
return components
|
return components
|
||||||
|
|
||||||
|
def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1):
|
||||||
|
upsampler = LTX2LatentUpsamplerModel(
|
||||||
|
in_channels=in_channels,
|
||||||
|
mid_channels=mid_channels,
|
||||||
|
num_blocks_per_stage=num_blocks_per_stage,
|
||||||
|
)
|
||||||
|
|
||||||
|
return upsampler
|
||||||
|
|
||||||
def get_dummy_inputs(self, device, seed=0):
|
def get_dummy_inputs(self, device, seed=0):
|
||||||
if str(device).startswith("mps"):
|
if str(device).startswith("mps"):
|
||||||
generator = torch.manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
@@ -287,5 +297,60 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||||
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
|
def test_two_stages_inference_with_upsampler(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
inputs["output_type"] = "latent"
|
||||||
|
first_stage_output = pipe(**inputs)
|
||||||
|
video_latent = first_stage_output.frames
|
||||||
|
audio_latent = first_stage_output.audio
|
||||||
|
|
||||||
|
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
|
||||||
|
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
|
||||||
|
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
|
||||||
|
|
||||||
|
upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1])
|
||||||
|
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler)
|
||||||
|
upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0]
|
||||||
|
self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32))
|
||||||
|
|
||||||
|
inputs["latents"] = upscaled_video_latent
|
||||||
|
inputs["audio_latents"] = audio_latent
|
||||||
|
inputs["output_type"] = "pt"
|
||||||
|
second_stage_output = pipe(**inputs)
|
||||||
|
video = second_stage_output.frames
|
||||||
|
audio = second_stage_output.audio
|
||||||
|
|
||||||
|
self.assertEqual(video.shape, (1, 5, 3, 64, 64))
|
||||||
|
self.assertEqual(audio.shape[0], 1)
|
||||||
|
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
expected_video_slice = torch.tensor(
|
||||||
|
[
|
||||||
|
0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079
|
||||||
|
]
|
||||||
|
)
|
||||||
|
expected_audio_slice = torch.tensor(
|
||||||
|
[
|
||||||
|
0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
video = video.flatten()
|
||||||
|
audio = audio.flatten()
|
||||||
|
generated_video_slice = torch.cat([video[:8], video[-8:]])
|
||||||
|
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
|
||||||
|
|
||||||
|
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||||
|
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
def test_inference_batch_single_identical(self):
|
def test_inference_batch_single_identical(self):
|
||||||
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
|
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
|
||||||
|
|||||||
Reference in New Issue
Block a user