mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-02 23:00:39 +08:00
Compare commits
4 Commits
use-fixtur
...
attn-backe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e33e1623f6 | ||
|
|
6648604f0d | ||
|
|
7220687151 | ||
|
|
73b159f2a1 |
3
.github/workflows/pypi_publish.yaml
vendored
3
.github/workflows/pypi_publish.yaml
vendored
@@ -54,6 +54,7 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install -U setuptools wheel twine
|
||||
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -U transformers
|
||||
|
||||
- name: Build the dist files
|
||||
run: python setup.py bdist_wheel && python setup.py sdist
|
||||
@@ -68,8 +69,6 @@ jobs:
|
||||
run: |
|
||||
pip install diffusers && pip uninstall diffusers -y
|
||||
pip install -i https://test.pypi.org/simple/ diffusers
|
||||
pip install -U transformers
|
||||
python utils/print_env.py
|
||||
python -c "from diffusers import __version__; print(__version__)"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
|
||||
|
||||
@@ -14,8 +14,4 @@
|
||||
|
||||
## AutoPipelineBlocks
|
||||
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
|
||||
|
||||
## ConditionalPipelineBlocks
|
||||
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConditionalPipelineBlocks
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
|
||||
@@ -46,20 +46,6 @@ output = pipe(
|
||||
output.save("output.png")
|
||||
```
|
||||
|
||||
## Cosmos2_5_TransferPipeline
|
||||
|
||||
[[autodoc]] Cosmos2_5_TransferPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## Cosmos2_5_PredictBasePipeline
|
||||
|
||||
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## CosmosTextToWorldPipeline
|
||||
|
||||
[[autodoc]] CosmosTextToWorldPipeline
|
||||
@@ -84,6 +70,12 @@ output.save("output.png")
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Cosmos2_5_PredictBasePipeline
|
||||
|
||||
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||
|
||||
@@ -121,7 +121,7 @@ from diffusers.modular_pipelines import AutoPipelineBlocks
|
||||
|
||||
class AutoImageBlocks(AutoPipelineBlocks):
|
||||
# List of sub-block classes to choose from
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
|
||||
# Names for each block in the same order
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
# Trigger inputs that determine which block to run
|
||||
@@ -129,8 +129,8 @@ class AutoImageBlocks(AutoPipelineBlocks):
|
||||
# - "image" triggers img2img workflow (but only if mask is not provided)
|
||||
# - if none of above, runs the text2img workflow (default)
|
||||
block_trigger_inputs = ["mask", "image", None]
|
||||
# Description is extremely important for AutoPipelineBlocks
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Pipeline generates images given different types of conditions!\n"
|
||||
@@ -141,7 +141,7 @@ class AutoImageBlocks(AutoPipelineBlocks):
|
||||
)
|
||||
```
|
||||
|
||||
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, its conditional logic may be difficult to figure out if it isn't properly explained.
|
||||
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, its conditional logic may be difficult to figure out if it isn't properly explained.
|
||||
|
||||
Create an instance of `AutoImageBlocks`.
|
||||
|
||||
@@ -152,74 +152,5 @@ auto_blocks = AutoImageBlocks()
|
||||
For more complex compositions, such as nested [`~modular_pipelines.AutoPipelineBlocks`] blocks when they're used as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the block that is actually run based on your input.
|
||||
|
||||
```py
|
||||
auto_blocks.get_execution_blocks(mask=True)
|
||||
```
|
||||
|
||||
## ConditionalPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.AutoPipelineBlocks`] is a special case of [`~modular_pipelines.ConditionalPipelineBlocks`]. While [`~modular_pipelines.AutoPipelineBlocks`] selects blocks based on whether a trigger input is provided or not, [`~modular_pipelines.ConditionalPipelineBlocks`] is able to select a block based on custom selection logic provided in the `select_block` method.
|
||||
|
||||
Here is the same example written using [`~modular_pipelines.ConditionalPipelineBlocks`] directly:
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import ConditionalPipelineBlocks
|
||||
|
||||
class AutoImageBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = "text2img"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Pipeline generates images given different types of conditions!\n"
|
||||
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
|
||||
+ " - inpaint workflow is run when `mask` is provided.\n"
|
||||
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
|
||||
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
|
||||
)
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None # falls back to default_block_name ("text2img")
|
||||
```
|
||||
|
||||
The inputs listed in `block_trigger_inputs` are passed as keyword arguments to `select_block()`. When `select_block` returns `None`, it falls back to `default_block_name`. If `default_block_name` is also `None`, the entire conditional block is skipped — this is useful for optional processing steps that should only run when specific inputs are provided.
|
||||
|
||||
## Workflows
|
||||
|
||||
Pipelines that contain conditional blocks ([`~modular_pipelines.AutoPipelineBlocks`] or [`~modular_pipelines.ConditionalPipelineBlocks`]) can support multiple workflows — for example, our SDXL modular pipeline supports a dozen workflows all in one pipeline. But this also means it can be confusing for users to know what workflows are supported and how to run them. For pipeline builders, it's useful to be able to extract only the blocks relevant to a specific workflow.
|
||||
|
||||
We recommend defining a `_workflow_map` to give each workflow a name and explicitly list the inputs it requires.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
|
||||
class MyPipelineBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock]
|
||||
block_names = ["text_encoder", "auto_image", "decode"]
|
||||
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image2image": {"image": True, "prompt": True},
|
||||
"inpaint": {"mask": True, "image": True, "prompt": True},
|
||||
}
|
||||
```
|
||||
|
||||
All of our built-in modular pipelines come with pre-defined workflows. The `available_workflows` property lists all supported workflows:
|
||||
|
||||
```py
|
||||
pipeline_blocks = MyPipelineBlocks()
|
||||
pipeline_blocks.available_workflows
|
||||
# ['text2image', 'image2image', 'inpaint']
|
||||
```
|
||||
|
||||
Retrieve a specific workflow with `get_workflow` to inspect and debug a specific block that executes the workflow.
|
||||
|
||||
```py
|
||||
pipeline_blocks.get_workflow("inpaint")
|
||||
auto_blocks.get_execution_blocks("mask")
|
||||
```
|
||||
@@ -111,7 +111,7 @@ if __name__ == "__main__":
|
||||
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
|
||||
|
||||
```bash
|
||||
torchrun --nproc_per_node=2 run_distributed.py
|
||||
torchrun run_distributed.py --nproc_per_node=2
|
||||
```
|
||||
|
||||
## device_map
|
||||
|
||||
@@ -94,15 +94,9 @@ python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/depth/pipeline \
|
||||
--output_path converted/transfer/2b/general/depth \
|
||||
--save_pipeline
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/depth/models
|
||||
|
||||
# edge
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
|
||||
|
||||
@@ -126,15 +120,9 @@ python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/blur/pipeline \
|
||||
--output_path converted/transfer/2b/general/blur \
|
||||
--save_pipeline
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/blur/models
|
||||
|
||||
# seg
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
|
||||
|
||||
@@ -142,14 +130,8 @@ python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/seg/pipeline \
|
||||
--output_path converted/transfer/2b/general/seg \
|
||||
--save_pipeline
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/seg/models
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
||||
|
||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
|
||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
|
||||
if has_diffb:
|
||||
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
||||
if zero_status_diff_b:
|
||||
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
state_dict = {
|
||||
_custom_replace(k, limit_substrings): v
|
||||
for k, v in state_dict.items()
|
||||
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
|
||||
if k.startswith(("lora_unet_", "lora_te_"))
|
||||
}
|
||||
|
||||
if any("text_projection" in k for k in state_dict):
|
||||
|
||||
@@ -62,8 +62,6 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
|
||||
_REQUIRED_XLA_VERSION = "2.2"
|
||||
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||
@@ -75,18 +73,8 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN:
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
# Handle ABI mismatch or other import failures gracefully.
|
||||
# This can happen when flash_attn was compiled against a different PyTorch version.
|
||||
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
|
||||
_CAN_USE_FLASH_ATTN = False
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
_wrapped_flash_attn_backward = None
|
||||
_wrapped_flash_attn_forward = None
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
else:
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
@@ -95,47 +83,26 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN_3:
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLASH_ATTN_3 = False
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
try:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_AITER_ATTN = False
|
||||
aiter_flash_attn_func = None
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
try:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_SAGE_ATTN = False
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp8_cuda = None
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
sageattn_qk_int8_pv_fp16_triton = None
|
||||
sageattn_varlen = None
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
else:
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
@@ -146,48 +113,26 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
try:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLEX_ATTN = False
|
||||
flex_attention = None
|
||||
else:
|
||||
flex_attention = None
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
|
||||
|
||||
if _CAN_USE_NPU_ATTN:
|
||||
try:
|
||||
from torch_npu import npu_fusion_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_NPU_ATTN = False
|
||||
npu_fusion_attention = None
|
||||
from torch_npu import npu_fusion_attention
|
||||
else:
|
||||
npu_fusion_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XLA_ATTN:
|
||||
try:
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XLA_ATTN = False
|
||||
xla_flash_attention = None
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
else:
|
||||
xla_flash_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XFORMERS_ATTN:
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XFORMERS_ATTN = False
|
||||
xops = None
|
||||
import xformers.ops as xops
|
||||
else:
|
||||
xops = None
|
||||
|
||||
@@ -213,6 +158,8 @@ else:
|
||||
_register_fake = register_fake_no_op
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# TODO(aryan): Add support for the following:
|
||||
# - Sage Attention++
|
||||
# - block sparse, radial and other attention methods
|
||||
@@ -329,11 +276,7 @@ class _HubKernelConfig:
|
||||
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3",
|
||||
function_attr="flash_attn_func",
|
||||
revision="fake-ops-return-probs",
|
||||
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
|
||||
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
),
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3",
|
||||
@@ -733,7 +676,7 @@ def _wrapped_flash_attn_3(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Hardcoded for now because pytorch does not support tuple/int type hints
|
||||
window_size = (-1, -1)
|
||||
result = flash_attn_3_func(
|
||||
out, lse, *_ = flash_attn_3_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
@@ -750,9 +693,7 @@ def _wrapped_flash_attn_3(
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=True,
|
||||
)
|
||||
out, lse, *_ = result
|
||||
lse = lse.permute(0, 2, 1)
|
||||
return out, lse
|
||||
|
||||
@@ -1296,62 +1237,36 @@ def _flash_attention_3_hub_forward_op(
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
|
||||
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
|
||||
wrapped_forward_fn = config.wrapped_forward_fn
|
||||
if wrapped_forward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
|
||||
"for context parallel execution."
|
||||
)
|
||||
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
||||
out, softmax_lse, *_ = wrapped_forward_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
None, # k_new, v_new
|
||||
None, # qv
|
||||
None, # out
|
||||
None,
|
||||
None,
|
||||
None, # cu_seqlens_q/k/k_new
|
||||
None,
|
||||
None, # seqused_q/k
|
||||
None,
|
||||
None, # max_seqlen_q/k
|
||||
None,
|
||||
None,
|
||||
None, # page_table, kv_batch_idx, leftpad_k
|
||||
None,
|
||||
None,
|
||||
None, # rotary_cos/sin, seqlens_rotary
|
||||
None,
|
||||
None,
|
||||
None, # q_descale, k_descale, v_descale
|
||||
scale,
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
window_size_left=window_size[0],
|
||||
window_size_right=window_size[1],
|
||||
attention_chunk=0,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
|
||||
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value, out, softmax_lse)
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.deterministic = deterministic
|
||||
ctx.sm_margin = sm_margin
|
||||
ctx._hub_kernel = func
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@@ -1360,49 +1275,54 @@ def _flash_attention_3_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: bool | None = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
):
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
|
||||
"for context parallel execution."
|
||||
query, key, value = ctx.saved_tensors
|
||||
kernel_fn = ctx._hub_kernel
|
||||
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
|
||||
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
|
||||
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
|
||||
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
|
||||
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
|
||||
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
|
||||
with torch.enable_grad():
|
||||
query_r = query.detach().requires_grad_(True)
|
||||
key_r = key.detach().requires_grad_(True)
|
||||
value_r = value.detach().requires_grad_(True)
|
||||
|
||||
out = kernel_fn(
|
||||
q=query_r,
|
||||
k=key_r,
|
||||
v=value_r,
|
||||
softmax_scale=ctx.scale,
|
||||
causal=ctx.is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=False,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
|
||||
query, key, value, out, softmax_lse = ctx.saved_tensors
|
||||
grad_query = torch.empty_like(query)
|
||||
grad_key = torch.empty_like(key)
|
||||
grad_value = torch.empty_like(value)
|
||||
|
||||
wrapped_backward_fn(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
softmax_lse,
|
||||
None,
|
||||
None, # cu_seqlens_q, cu_seqlens_k
|
||||
None,
|
||||
None, # seqused_q, seqused_k
|
||||
None,
|
||||
None, # max_seqlen_q, max_seqlen_k
|
||||
grad_query,
|
||||
grad_key,
|
||||
grad_value,
|
||||
ctx.scale,
|
||||
ctx.is_causal,
|
||||
ctx.window_size[0],
|
||||
ctx.window_size[1],
|
||||
ctx.softcap,
|
||||
ctx.deterministic,
|
||||
ctx.sm_margin,
|
||||
)
|
||||
|
||||
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||
grad_query, grad_key, grad_value = torch.autograd.grad(
|
||||
out,
|
||||
(query_r, key_r, value_r),
|
||||
grad_out,
|
||||
retain_graph=False,
|
||||
allow_unused=False,
|
||||
)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
@@ -2703,7 +2623,7 @@ def _flash_varlen_attention_3(
|
||||
key_packed = torch.cat(key_valid, dim=0)
|
||||
value_packed = torch.cat(value_valid, dim=0)
|
||||
|
||||
result = flash_attn_3_varlen_func(
|
||||
out, lse, *_ = flash_attn_3_varlen_func(
|
||||
q=query_packed,
|
||||
k=key_packed,
|
||||
v=value_packed,
|
||||
@@ -2713,13 +2633,7 @@ def _flash_varlen_attention_3(
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if isinstance(result, tuple):
|
||||
out, lse, *_ = result
|
||||
else:
|
||||
out = result
|
||||
lse = None
|
||||
out = out.unflatten(0, (batch_size, -1))
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@@ -191,12 +191,7 @@ class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
dim=1,
|
||||
)
|
||||
|
||||
if condition_mask is not None:
|
||||
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
|
||||
else:
|
||||
control_hidden_states = torch.cat(
|
||||
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
|
||||
)
|
||||
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
|
||||
|
||||
padding_mask_resized = transforms.functional.resize(
|
||||
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
|
||||
@@ -1633,14 +1633,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
blocks_class_name = self.default_blocks_name
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name, None)
|
||||
# If the blocks_class is not found or is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes
|
||||
# fall back to default_blocks_name
|
||||
if blocks_class is None or not blocks_class.block_classes:
|
||||
blocks_class_name = self.default_blocks_name
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
|
||||
if blocks_class is not None:
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
@@ -17,6 +17,9 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms
|
||||
import torchvision.transforms.functional
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
@@ -51,13 +54,11 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int):
|
||||
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
|
||||
n_pad_frames = num_frames - video.shape[2]
|
||||
if n_pad_frames > 0:
|
||||
last_frame = video[:, :, -1:, :, :]
|
||||
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
|
||||
elif num_frames < video.shape[2]:
|
||||
video = video[:, :, :num_frames, :, :]
|
||||
return video
|
||||
|
||||
|
||||
@@ -133,8 +134,8 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
|
||||
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
|
||||
|
||||
>>> # Transfer inference with controls.
|
||||
>>> video = pipe(
|
||||
... video=input_video[:num_frames],
|
||||
... controls=controls,
|
||||
... controls_conditioning_scale=1.0,
|
||||
... prompt=prompt,
|
||||
@@ -148,7 +149,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference.
|
||||
Pipeline for Cosmos Transfer2.5 base model.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
@@ -165,14 +166,12 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
controlnet ([`CosmosControlNetModel`]):
|
||||
ControlNet used to condition generation on control inputs.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
_optional_components = ["safety_checker", "controlnet"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
@@ -182,8 +181,8 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: UniPCMultistepScheduler,
|
||||
controlnet: CosmosControlNetModel,
|
||||
safety_checker: Optional[CosmosSafetyChecker] = None,
|
||||
controlnet: Optional[CosmosControlNetModel],
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -385,11 +384,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
num_frames_in: int = 93,
|
||||
num_frames_out: int = 93,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
num_cond_latent_frames: int = 0,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -404,14 +402,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
W = width // self.vae_scale_factor_spatial
|
||||
shape = (B, C, T, H, W)
|
||||
|
||||
if latents is not None:
|
||||
if latents.shape[1:] != shape[1:]:
|
||||
raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.")
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
else:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
if num_frames_in == 0:
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
|
||||
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
|
||||
|
||||
@@ -441,12 +435,16 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||
cond_latents = (cond_latents - latents_mean) / latents_std
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
padding_shape = (B, 1, T, H, W)
|
||||
ones_padding = latents.new_ones(padding_shape)
|
||||
zeros_padding = latents.new_zeros(padding_shape)
|
||||
|
||||
cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1)
|
||||
cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0
|
||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||
|
||||
return (
|
||||
@@ -456,7 +454,34 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
cond_indicator,
|
||||
)
|
||||
|
||||
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
def _encode_controls(
|
||||
self,
|
||||
controls: Optional[torch.Tensor],
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: torch.Generator | list[torch.Generator] | None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if controls is None:
|
||||
return None
|
||||
|
||||
control_video = self.video_processor.preprocess_video(controls, height, width)
|
||||
control_video = _maybe_pad_video(control_video, num_frames)
|
||||
|
||||
control_video = control_video.to(device=device, dtype=self.vae.dtype)
|
||||
control_latents = [
|
||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
|
||||
]
|
||||
control_latents = torch.cat(control_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
|
||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||
control_latents = (control_latents - latents_mean) / latents_std
|
||||
return control_latents
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -464,25 +489,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
num_ar_conditional_frames=None,
|
||||
num_ar_latent_conditional_frames=None,
|
||||
num_frames_per_chunk=None,
|
||||
num_frames=None,
|
||||
conditional_frame_timestep=0.1,
|
||||
):
|
||||
if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}."
|
||||
)
|
||||
|
||||
if num_frames is not None and num_frames <= 0:
|
||||
raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.")
|
||||
|
||||
if conditional_frame_timestep < 0 or conditional_frame_timestep > 1:
|
||||
raise ValueError(
|
||||
"`conditional_frame_timestep` has to be a float in the [0, 1] interval but is "
|
||||
f"{conditional_frame_timestep}."
|
||||
)
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -503,46 +512,6 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None:
|
||||
raise ValueError(
|
||||
"Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both."
|
||||
)
|
||||
if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None:
|
||||
raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.")
|
||||
if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0:
|
||||
raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.")
|
||||
if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0:
|
||||
raise ValueError("`num_ar_conditional_frames` must be >= 0.")
|
||||
|
||||
if num_ar_latent_conditional_frames is not None:
|
||||
num_ar_conditional_frames = max(
|
||||
0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1
|
||||
)
|
||||
|
||||
min_chunk_len = self.vae_scale_factor_temporal + 1
|
||||
if num_frames_per_chunk < min_chunk_len:
|
||||
logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len")
|
||||
num_frames_per_chunk = min_chunk_len
|
||||
|
||||
max_frames_by_rope = None
|
||||
if getattr(self.transformer.config, "max_size", None) is not None:
|
||||
max_frames_by_rope = max(
|
||||
size // patch
|
||||
for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size)
|
||||
)
|
||||
if num_frames_per_chunk > max_frames_by_rope:
|
||||
raise ValueError(
|
||||
f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). "
|
||||
"Please reduce `num_frames_per_chunk`."
|
||||
)
|
||||
|
||||
if num_ar_conditional_frames >= num_frames_per_chunk:
|
||||
raise ValueError(
|
||||
f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation."
|
||||
)
|
||||
|
||||
return num_frames_per_chunk
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -567,22 +536,23 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
controls: PipelineImageInput | List[PipelineImageInput],
|
||||
controls_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
image: PipelineImageInput | None = None,
|
||||
video: List[PipelineImageInput] | None = None,
|
||||
prompt: Union[str, List[str]] | None = None,
|
||||
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
|
||||
height: int = 704,
|
||||
width: Optional[int] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
num_frames_per_chunk: int = 93,
|
||||
width: int | None = None,
|
||||
num_frames: int = 93,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 3.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
|
||||
controls_conditioning_scale: float | list[float] = 1.0,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
@@ -590,26 +560,24 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
conditional_frame_timestep: float = 0.1,
|
||||
num_ar_conditional_frames: Optional[int] = 1,
|
||||
num_ar_latent_conditional_frames: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
`controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps
|
||||
are pre-computed.
|
||||
The call function to the pipeline for generation. Supports three modes:
|
||||
|
||||
Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None
|
||||
(default) then the number of output frames will match the input `controls`.
|
||||
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
|
||||
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
|
||||
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
|
||||
|
||||
Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames are used per
|
||||
denoising loop. In addition, when auto-regressive inference is performed, the previous
|
||||
`num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising
|
||||
inference loops.
|
||||
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
|
||||
above in "*2Image mode").
|
||||
|
||||
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
|
||||
|
||||
Args:
|
||||
controls (`PipelineImageInput`, `List[PipelineImageInput]`):
|
||||
Control image or video input used by the ControlNet.
|
||||
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
|
||||
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
|
||||
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
|
||||
height (`int`, defaults to `704`):
|
||||
@@ -617,10 +585,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image. If not provided, this will be determined based on the
|
||||
aspect ratio of the input and the provided height.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of output frames. Defaults to `None` to output the same number of frames as the input
|
||||
`controls`.
|
||||
num_inference_steps (`int`, defaults to `36`):
|
||||
num_frames (`int`, defaults to `93`):
|
||||
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
|
||||
num_inference_steps (`int`, defaults to `35`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `3.0`):
|
||||
@@ -634,9 +601,13 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
|
||||
tweak the same generation with different prompts. If not provided, a latents tensor is generated by
|
||||
sampling using the supplied random `generator`.
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
|
||||
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
|
||||
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
|
||||
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
@@ -659,18 +630,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||
the prompt is shorter than this length, it will be padded.
|
||||
num_ar_conditional_frames (`int`, *optional*, defaults to `1`):
|
||||
Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the
|
||||
second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`.
|
||||
|
||||
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
|
||||
controls is > num_frames_per_chunk
|
||||
num_ar_latent_conditional_frames (`int`, *optional*):
|
||||
Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e.
|
||||
for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`.
|
||||
|
||||
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
|
||||
controls is > num_frames_per_chunk
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
@@ -690,40 +650,21 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
if width is None:
|
||||
frame = controls[0] if isinstance(controls, list) else controls
|
||||
if isinstance(frame, list):
|
||||
frame = frame[0]
|
||||
if isinstance(frame, (torch.Tensor, np.ndarray)):
|
||||
if frame.ndim == 5:
|
||||
frame = frame[0, 0]
|
||||
elif frame.ndim == 4:
|
||||
frame = frame[0]
|
||||
frame = image or video[0] if image or video else None
|
||||
if frame is None and controls is not None:
|
||||
frame = controls[0] if isinstance(controls, list) else controls
|
||||
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
|
||||
frame = controls[0]
|
||||
|
||||
if isinstance(frame, PIL.Image.Image):
|
||||
if frame is None:
|
||||
width = int((height + 16) * (1280 / 720))
|
||||
elif isinstance(frame, PIL.Image.Image):
|
||||
width = int((height + 16) * (frame.width / frame.height))
|
||||
else:
|
||||
if frame.ndim != 3:
|
||||
raise ValueError("`controls` must contain 3D frames in CHW format.")
|
||||
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
|
||||
|
||||
num_frames_per_chunk = self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
num_ar_conditional_frames,
|
||||
num_ar_latent_conditional_frames,
|
||||
num_frames_per_chunk,
|
||||
num_frames,
|
||||
conditional_frame_timestep,
|
||||
)
|
||||
|
||||
if num_ar_latent_conditional_frames is not None:
|
||||
num_cond_latent_frames = num_ar_latent_conditional_frames
|
||||
num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
|
||||
else:
|
||||
num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
|
||||
# Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
@@ -768,137 +709,102 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
if getattr(self.transformer.config, "img_context_dim_in", None):
|
||||
img_context = torch.zeros(
|
||||
batch_size,
|
||||
self.transformer.config.img_context_num_tokens,
|
||||
self.transformer.config.img_context_dim_in,
|
||||
device=prompt_embeds.device,
|
||||
img_context = torch.zeros(
|
||||
batch_size,
|
||||
self.transformer.config.img_context_num_tokens,
|
||||
self.transformer.config.img_context_dim_in,
|
||||
device=prompt_embeds.device,
|
||||
dtype=transformer_dtype,
|
||||
)
|
||||
encoder_hidden_states = (prompt_embeds, img_context)
|
||||
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
|
||||
|
||||
num_frames_in = None
|
||||
if image is not None:
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
|
||||
|
||||
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
|
||||
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
|
||||
video = video.unsqueeze(0)
|
||||
num_frames_in = 1
|
||||
elif video is None:
|
||||
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
|
||||
num_frames_in = 0
|
||||
else:
|
||||
num_frames_in = len(video)
|
||||
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
|
||||
|
||||
assert video is not None
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
|
||||
# pad with last frame (for video2world)
|
||||
num_frames_out = num_frames
|
||||
video = _maybe_pad_video(video, num_frames_out)
|
||||
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
|
||||
|
||||
video = video.to(device=device, dtype=vae_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels - 1
|
||||
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames_in=num_frames_in,
|
||||
num_frames_out=num_frames,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
|
||||
controls_latents = None
|
||||
if controls is not None:
|
||||
controls_latents = self._encode_controls(
|
||||
controls,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
dtype=transformer_dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
if num_videos_per_prompt > 1:
|
||||
img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
encoder_hidden_states = (prompt_embeds, img_context)
|
||||
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
|
||||
else:
|
||||
encoder_hidden_states = prompt_embeds
|
||||
neg_encoder_hidden_states = negative_prompt_embeds
|
||||
# Denoising loop
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
self._num_timesteps = len(timesteps)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
control_video = self.video_processor.preprocess_video(controls, height, width)
|
||||
if control_video.shape[0] != batch_size:
|
||||
if control_video.shape[0] == 1:
|
||||
control_video = control_video.repeat(batch_size, 1, 1, 1, 1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}."
|
||||
gt_velocity = (latents - cond_latent) * cond_mask
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t.cpu().item()
|
||||
|
||||
# NOTE: assumes sigma(t) \in [0, 1]
|
||||
sigma_t = (
|
||||
torch.tensor(self.scheduler.sigmas[i].item())
|
||||
.unsqueeze(0)
|
||||
.to(device=device, dtype=transformer_dtype)
|
||||
)
|
||||
|
||||
num_frames_out = control_video.shape[2]
|
||||
if num_frames is not None:
|
||||
num_frames_out = min(num_frames_out, num_frames)
|
||||
|
||||
control_video = _maybe_pad_or_trim_video(control_video, num_frames_out)
|
||||
|
||||
# chunk information
|
||||
num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1
|
||||
chunk_stride = num_frames_per_chunk - num_ar_conditional_frames
|
||||
chunk_idxs = [
|
||||
(start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
|
||||
for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride)
|
||||
]
|
||||
|
||||
video_chunks = []
|
||||
latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device)
|
||||
latents_std = self.latents_std.to(dtype=vae_dtype, device=device)
|
||||
|
||||
def decode_latents(latents):
|
||||
latents = latents * latents_std + latents_mean
|
||||
video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
|
||||
return video
|
||||
|
||||
latents_arg = latents
|
||||
initial_num_cond_latent_frames = 0
|
||||
latent_chunks = []
|
||||
num_chunks = len(chunk_idxs)
|
||||
total_steps = num_inference_steps * num_chunks
|
||||
with self.progress_bar(total=total_steps) as progress_bar:
|
||||
for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs):
|
||||
if chunk_idx == 0:
|
||||
prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype)
|
||||
prev_output = self.video_processor.preprocess_video(prev_output, height, width)
|
||||
else:
|
||||
prev_output = video_chunks[-1].clone()
|
||||
if num_ar_conditional_frames > 0:
|
||||
prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:]
|
||||
prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space
|
||||
else:
|
||||
prev_output.fill_(-1)
|
||||
|
||||
chunk_video = prev_output.to(device=device, dtype=vae_dtype)
|
||||
chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk)
|
||||
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
||||
video=chunk_video,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=self.transformer.config.in_channels - 1,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames_in=chunk_video.shape[2],
|
||||
num_frames_out=num_frames_per_chunk,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
num_cond_latent_frames=initial_num_cond_latent_frames
|
||||
if chunk_idx == 0
|
||||
else num_cond_latent_frames,
|
||||
latents=latents_arg,
|
||||
)
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to(
|
||||
device=device, dtype=self.vae.dtype
|
||||
)
|
||||
chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk)
|
||||
if isinstance(generator, list):
|
||||
controls_latents = [
|
||||
retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i])
|
||||
for i in range(chunk_control_video.shape[0])
|
||||
]
|
||||
else:
|
||||
controls_latents = [
|
||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator)
|
||||
for vid in chunk_control_video
|
||||
]
|
||||
controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype)
|
||||
|
||||
controls_latents = (controls_latents - latents_mean) / latents_std
|
||||
|
||||
# Denoising loop
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
gt_velocity = (latents - cond_latent) * cond_mask
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t.cpu().item()
|
||||
|
||||
# NOTE: assumes sigma(t) \in [0, 1]
|
||||
sigma_t = (
|
||||
torch.tensor(self.scheduler.sigmas[i].item())
|
||||
.unsqueeze(0)
|
||||
.to(device=device, dtype=transformer_dtype)
|
||||
)
|
||||
|
||||
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
||||
in_latents = in_latents.to(transformer_dtype)
|
||||
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
||||
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
||||
in_latents = in_latents.to(transformer_dtype)
|
||||
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
||||
control_blocks = None
|
||||
if controls_latents is not None and self.controlnet is not None:
|
||||
control_output = self.controlnet(
|
||||
controls_latents=controls_latents,
|
||||
latents=in_latents,
|
||||
@@ -911,18 +817,20 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
)
|
||||
control_blocks = control_output[0]
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
control_blocks = None
|
||||
if controls_latents is not None and self.controlnet is not None:
|
||||
control_output = self.controlnet(
|
||||
controls_latents=controls_latents,
|
||||
latents=in_latents,
|
||||
@@ -935,50 +843,46 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
)
|
||||
control_blocks = control_output[0]
|
||||
|
||||
noise_pred_neg = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
||||
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
||||
noise_pred_neg = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
||||
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
video_chunks.append(decode_latents(latents).detach().cpu())
|
||||
latent_chunks.append(latents.detach().cpu())
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
video_chunks = [
|
||||
chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk
|
||||
for chunk_idx, chunk in enumerate(video_chunks)
|
||||
]
|
||||
video = torch.cat(video_chunks, dim=2)
|
||||
video = video[:, :, :num_frames_out, ...]
|
||||
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
|
||||
latents_std = self.latents_std.to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std + latents_mean
|
||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||
video = self._match_num_frames(video, num_frames)
|
||||
|
||||
assert self.safety_checker is not None
|
||||
self.safety_checker.to(device)
|
||||
@@ -995,13 +899,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_chunks = [
|
||||
chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
|
||||
for chunk_idx, chunk in enumerate(latent_chunks)
|
||||
]
|
||||
video = torch.cat(latent_chunks, dim=2)
|
||||
video = video[:, :, :latent_T, ...]
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
@@ -1010,3 +908,19 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
|
||||
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
    """Return `video` with exactly `target_num_frames` entries along dim 2.

    The input is returned untouched when `target_num_frames` is not positive or when
    dim 2 already has the requested length. Otherwise every entry along dim 2 is
    repeated `max(self.vae_scale_factor_temporal, 1)` times, and the result is either
    padded with copies of its last entry or truncated to the requested length.

    Args:
        video: 5-D tensor; dim 2 is the axis being resized.
        target_num_frames: Desired size of dim 2.

    Returns:
        The original tensor, or a new tensor whose dim 2 has size `target_num_frames`.
    """
    if target_num_frames <= 0:
        return video
    if video.shape[2] == target_num_frames:
        return video

    # Guard against a zero/negative scale factor: repeat at least once.
    repeats = max(self.vae_scale_factor_temporal, 1)
    expanded = torch.repeat_interleave(video, repeats=repeats, dim=2)

    # Positive: too few frames (pad); negative: too many (truncate).
    missing = target_num_frames - expanded.shape[2]
    if missing > 0:
        last_frame = expanded[:, :, -1:, :, :]
        tail = last_frame.repeat(1, 1, missing, 1, 1)
        return torch.cat([expanded, tail], dim=2)
    if missing < 0:
        return expanded[:, :, :target_num_frames]
    return expanded
|
||||
|
||||
@@ -699,13 +699,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
mask_shape = (batch_size, 1, num_frames, height, width)
|
||||
|
||||
if latents is not None:
|
||||
conditioning_mask = latents.new_zeros(mask_shape)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
if latents.ndim == 5:
|
||||
# conditioning_mask needs to be the same shape as latents in two-stage generation.
|
||||
batch_size, _, num_frames, height, width = latents.shape
|
||||
mask_shape = (batch_size, 1, num_frames, height, width)
|
||||
conditioning_mask = latents.new_zeros(mask_shape)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
@@ -714,9 +710,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
else:
|
||||
conditioning_mask = latents.new_zeros(mask_shape)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
).squeeze(-1)
|
||||
|
||||
@@ -276,7 +276,7 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 0
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
|
||||
@@ -131,26 +131,6 @@ class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsInstance(output[0], list)
|
||||
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
|
||||
|
||||
def test_condition_mask_changes_output(self):
    """An all-zero condition mask must yield different control outputs than the default mask."""
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    # Same inputs, except that the condition mask is replaced by zeros of the same shape.
    zeroed_mask = torch.zeros_like(inputs_dict["condition_mask"])
    zero_mask_inputs = {**inputs_dict, "condition_mask": zeroed_mask}

    with torch.no_grad():
        output_no_mask = model(**zero_mask_inputs)
        output_with_mask = model(**inputs_dict)

    samples_no_mask = output_no_mask.control_block_samples
    samples_with_mask = output_with_mask.control_block_samples

    self.assertEqual(len(samples_no_mask), len(samples_with_mask))
    # Every control block must react to the mask, not just one of them.
    for sample_no_mask, sample_with_mask in zip(samples_no_mask, samples_with_mask):
        self.assertFalse(torch.allclose(sample_no_mask, sample_with_mask))
|
||||
|
||||
def test_conditioning_scale_single(self):
|
||||
"""Test that a single conditioning scale is broadcast to all blocks."""
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .attention import AttentionTesterMixin
|
||||
from .attention import AttentionBackendTesterMixin, AttentionTesterMixin
|
||||
from .cache import (
|
||||
CacheTesterMixin,
|
||||
FasterCacheConfigMixin,
|
||||
@@ -38,6 +38,7 @@ from .training import TrainingTesterMixin
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionBackendTesterMixin",
|
||||
"AttentionTesterMixin",
|
||||
"BaseModelTesterConfig",
|
||||
"BitsAndBytesCompileTesterMixin",
|
||||
|
||||
@@ -14,22 +14,105 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.utils import is_kernels_available, is_torch_version
|
||||
|
||||
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level backend parameter sets for AttentionBackendTesterMixin
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
_KERNELS_AVAILABLE = is_kernels_available()
|
||||
|
||||
_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native")
|
||||
|
||||
_PARAM_NATIVE_CUDNN = pytest.param(
|
||||
AttentionBackendName._NATIVE_CUDNN,
|
||||
id="native_cudnn",
|
||||
marks=pytest.mark.skipif(
|
||||
not _CUDA_AVAILABLE,
|
||||
reason="CUDA is required for _native_cudnn backend.",
|
||||
),
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
is_attention,
|
||||
torch_device,
|
||||
_PARAM_FLASH_HUB = pytest.param(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
id="flash_hub",
|
||||
marks=[
|
||||
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."),
|
||||
pytest.mark.skipif(
|
||||
not _KERNELS_AVAILABLE,
|
||||
reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
_PARAM_FLASH_3_HUB = pytest.param(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
id="flash_3_hub",
|
||||
marks=[
|
||||
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."),
|
||||
pytest.mark.skipif(
|
||||
not _KERNELS_AVAILABLE,
|
||||
reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# All backends under test.
|
||||
_ALL_BACKEND_PARAMS = [_PARAM_NATIVE, _PARAM_NATIVE_CUDNN, _PARAM_FLASH_HUB, _PARAM_FLASH_3_HUB]
|
||||
|
||||
# Backends that only accept bf16/fp16 inputs; models and inputs must be cast before running them.
|
||||
_BF16_REQUIRED_BACKENDS = {
|
||||
AttentionBackendName._NATIVE_CUDNN,
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
}
|
||||
|
||||
# Backends that perform non-deterministic operations and therefore cannot run when
|
||||
# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()).
|
||||
_NON_DETERMINISTIC_BACKENDS = {AttentionBackendName._NATIVE_CUDNN}
|
||||
|
||||
|
||||
def _maybe_cast_to_bf16(backend, model, inputs_dict):
    """Return `(model, inputs_dict)`, converted to bfloat16 if `backend` demands it.

    Backends not listed in `_BF16_REQUIRED_BACKENDS` get both arguments back unchanged.
    For listed backends the model is moved to bfloat16 and a new dict is built in which
    every floating-point tensor value is converted; all other values are passed through.
    """
    if backend in _BF16_REQUIRED_BACKENDS:
        model = model.to(dtype=torch.bfloat16)

        converted = {}
        for key, value in inputs_dict.items():
            # Integer/bool tensors and non-tensor values must keep their original type.
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                value = value.to(dtype=torch.bfloat16)
            converted[key] = value
        inputs_dict = converted

    return model, inputs_dict
|
||||
|
||||
|
||||
def _skip_if_backend_requires_nondeterminism(backend):
    """Skip the running test if `backend` is non-deterministic and determinism is enforced.

    The check happens at test execution time on purpose: test files usually call
    enable_full_determinism() at module level *after* the module-level pytest.param()
    objects of this file were evaluated, so a collection-time skipif condition
    would not see it.
    """
    if backend not in _NON_DETERMINISTIC_BACKENDS:
        return
    if not torch.are_deterministic_algorithms_enabled():
        return

    pytest.skip(
        f"Backend '{backend.value}' performs non-deterministic operations and cannot run "
        f"while `torch.use_deterministic_algorithms(True)` is active."
    )
|
||||
|
||||
|
||||
@is_attention
|
||||
class AttentionTesterMixin:
|
||||
@@ -39,7 +122,6 @@ class AttentionTesterMixin:
|
||||
Tests functionality from AttentionModuleMixin including:
|
||||
- Attention processor management (set/get)
|
||||
- QKV projection fusion/unfusion
|
||||
- Attention backends (XFormers, NPU, etc.)
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
@@ -179,3 +261,208 @@ class AttentionTesterMixin:
|
||||
model.set_attn_processor(wrong_processors)
|
||||
|
||||
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
|
||||
|
||||
|
||||
@is_attention
class AttentionBackendTesterMixin:
    """
    Mixin class for testing attention backends on models. The following things are tested:

    1. Backends can be set with the `attention_backend` context manager and with
       the `set_attention_backend()` method.
    2. SDPA outputs don't deviate too much from backend outputs.
    3. Backend works with (regional) compilation.
    4. Backends can be restored.

    Tests the backends using the model provided by the host test class. The backends to test
    are defined in `_ALL_BACKEND_PARAMS`.

    Expected from the host test class:
        - model_class: The model class to instantiate.

    Expected methods from the host test class:
        - get_init_dict(): Returns dict of kwargs to construct the model.
        - get_dummy_inputs(): Returns dict of inputs for the model's forward pass.

    Pytest mark: attention
        Use `pytest -m "not attention"` to skip these tests.
    """

    # -----------------------------------------------------------------------
    # Tolerance attributes — override in host class to loosen/tighten checks.
    # -----------------------------------------------------------------------

    # test_output_close_to_native: alternate backends (flash, cuDNN) may
    # accumulate small numerical errors vs the reference PyTorch SDPA kernel.
    backend_vs_native_atol: float = 1e-2
    backend_vs_native_rtol: float = 1e-2

    # test_compile: regional compilation introduces the same kind of numerical
    # error as the non-compiled backend path, so the same loose tolerance applies.
    compile_vs_native_atol: float = 1e-2
    compile_vs_native_rtol: float = 1e-2

    def setup_method(self):
        # Run garbage collection and empty the device cache before each test.
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        # Same cleanup after each test, so a test does not leave memory behind for the next one.
        gc.collect()
        backend_empty_cache(torch_device)

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_set_attention_backend_matches_context_manager(self, backend):
        """set_attention_backend() and the attention_backend() context manager must yield identical outputs."""
        _skip_if_backend_requires_nondeterminism(backend)

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

        # Reference run: backend selected through the context manager.
        with attention_backend(backend):
            ctx_output = model(**inputs_dict, return_dict=False)[0]

        # Remember the registry's active backend so it can be put back in the `finally` below.
        # NOTE(review): presumably set_attention_backend() may change the registry state — confirm.
        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        # Any failure to select the backend is treated as "not usable in this environment".
        try:
            model.set_attention_backend(backend.value)
        except Exception as e:
            logger.warning("Skipping test for backend '%s': %s", backend.value, e)
            pytest.skip(str(e))

        try:
            set_output = model(**inputs_dict, return_dict=False)[0]
        finally:
            # Always undo the per-model backend and restore the registry, even if the forward pass fails.
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        # Zero tolerances: both selection mechanisms must give exactly the same output.
        assert_tensors_close(
            set_output,
            ctx_output,
            atol=0,
            rtol=0,
            msg=(
                f"Output from model.set_attention_backend('{backend.value}') should be identical "
                f"to the output from `with attention_backend('{backend.value}'):`."
            ),
        )

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_output_close_to_native(self, backend):
        """All backends should produce model output numerically close to the native SDPA reference."""
        _skip_if_backend_requires_nondeterminism(backend)

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # The cast happens before the reference run, so both runs use the same model and input dtype.
        model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

        # Reference run with the native backend.
        with attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**inputs_dict, return_dict=False)[0]

        # Remember the registry's active backend so it can be put back in the `finally` below.
        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        # Any failure to select the backend is treated as "not usable in this environment".
        try:
            model.set_attention_backend(backend.value)
        except Exception as e:
            logger.warning("Skipping test for backend '%s': %s", backend.value, e)
            pytest.skip(str(e))

        try:
            backend_output = model(**inputs_dict, return_dict=False)[0]
        finally:
            # Always undo the per-model backend and restore the registry, even if the forward pass fails.
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        # Tolerances come from class attributes so host classes can override them.
        assert_tensors_close(
            backend_output,
            native_output,
            atol=self.backend_vs_native_atol,
            rtol=self.backend_vs_native_rtol,
            msg=f"Output from {backend} should be numerically close to native SDPA.",
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_context_manager_switches_and_restores_backend(self, backend):
        """attention_backend() should activate the requested backend and restore the previous one on exit."""
        # No model is built here: only the registry state is inspected.
        initial_backend, _ = _AttentionBackendRegistry.get_active_backend()

        with attention_backend(backend):
            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
            assert active_backend == backend, (
                f"Backend should be {backend} inside the context manager, got {active_backend}."
            )

        restored_backend, _ = _AttentionBackendRegistry.get_active_backend()
        assert restored_backend == initial_backend, (
            f"Backend should be restored to {initial_backend} after exiting the context manager, "
            f"got {restored_backend}."
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_compile(self, backend):
        """
        `torch.compile` tests checking for recompilation, graph breaks, forward can run, etc.
        For speed, we use regional compilation here (`model.compile_repeated_blocks()`
        as opposed to `model.compile`).

        The compiled output is also compared against an eager run with the native backend.
        """
        _skip_if_backend_requires_nondeterminism(backend)
        # Regional compilation is only attempted for model classes that declare `_repeated_blocks`.
        if getattr(self.model_class, "_repeated_blocks", None) is None:
            pytest.skip("Skipping tests as regional compilation is not supported.")

        if backend == AttentionBackendName.NATIVE and not is_torch_version(">=", "2.9.0"):
            pytest.xfail(
                "test_compile with the native backend requires torch >= 2.9.0 for stable "
                "fullgraph compilation with error_on_recompile=True."
            )

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

        # Eager reference run with the native backend, taken before the model is compiled.
        with torch.no_grad(), attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**inputs_dict, return_dict=False)[0]

        # Remember the registry's active backend so it can be put back in the `finally` below.
        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        # Any failure to select the backend is treated as "not usable in this environment".
        try:
            model.set_attention_backend(backend.value)
        except Exception as e:
            logger.warning("Skipping test for backend '%s': %s", backend.value, e)
            pytest.skip(str(e))

        try:
            model.compile_repeated_blocks(fullgraph=True)
            torch.compiler.reset()

            with (
                torch._inductor.utils.fresh_inductor_cache(),
                torch._dynamo.config.patch(error_on_recompile=True),
            ):
                with torch.no_grad():
                    compile_output = model(**inputs_dict, return_dict=False)[0]
                    # Second call with identical inputs; with error_on_recompile=True this is
                    # presumably where an unexpected recompilation would surface as an error.
                    model(**inputs_dict, return_dict=False)
        finally:
            # Always undo the per-model backend and restore the registry, even if compilation fails.
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        # Tolerances come from class attributes so host classes can override them.
        assert_tensors_close(
            compile_output,
            native_output,
            atol=self.compile_vs_native_atol,
            rtol=self.compile_vs_native_rtol,
            msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.",
        )
|
||||
|
||||
@@ -25,6 +25,7 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionBackendTesterMixin,
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
@@ -224,6 +225,10 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerAttentionBackend(FluxTransformerTesterConfig, AttentionBackendTesterMixin):
    """Attention backend tests for Flux Transformer.

    Defines no tests of its own: everything is inherited from
    `FluxTransformerTesterConfig` and `AttentionBackendTesterMixin`.
    """
|
||||
|
||||
|
||||
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux Transformer"""
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
@@ -128,16 +129,18 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||
|
||||
return inputs
|
||||
|
||||
def test_save_from_pretrained(self, tmp_path):
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
base_pipe.save_pretrained(tmp_path)
|
||||
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
@@ -209,16 +212,18 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
||||
|
||||
return inputs
|
||||
|
||||
def test_save_from_pretrained(self, tmp_path):
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
base_pipe.save_pretrained(tmp_path)
|
||||
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import gc
|
||||
import json
|
||||
import tempfile
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
@@ -328,15 +328,16 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_save_from_pretrained(self, tmp_path):
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
base_pipe.save_pretrained(tmp_path)
|
||||
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
@@ -348,32 +349,6 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
    """Check that `save_pretrained` writes a `modular_model_index.json` matching the component specs.

    The pipeline from `self.get_pipeline()` is saved into `tmp_path`. The index file must
    exist and contain the compulsory metadata keys. For every component whose spec has
    `pretrained_model_name_or_path` set, the entry's third element must carry the same
    `pretrained_model_name_or_path`, `revision` and `subfolder` as the spec.

    Args:
        tmp_path: Directory the pipeline is saved into (joined with `/`, so a path object).
    """
    pipe = self.get_pipeline()
    components_spec = pipe._component_specs
    components = sorted(components_spec.keys())

    pipe.save_pretrained(tmp_path)
    index_file = tmp_path / "modular_model_index.json"
    assert index_file.exists()

    with open(index_file) as f:
        index_contents = json.load(f)

    compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
    for k in compulsory_keys:
        assert k in index_contents

    to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
    for component in components:
        spec = components_spec[component]
        # Components whose spec has no pretrained_model_name_or_path are not checked.
        if getattr(spec, "pretrained_model_name_or_path", None) is None:
            continue

        # Checked once per component rather than once per attribute.
        assert component in index_contents, f"{component} should be present in index but isn't."
        # A single pass over the attributes; the original nested two identical loops over
        # `to_check_attrs`, which only repeated the same assertions.
        for attr in to_check_attrs:
            attr_value_from_index = index_contents[component][2][attr]
            assert getattr(spec, attr) == attr_value_from_index
|
||||
|
||||
def test_workflow_map(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
if blocks._workflow_map is None:
|
||||
@@ -724,27 +699,3 @@ class TestLoadComponentsSkipBehavior:
|
||||
|
||||
# Verify test_component was not loaded
|
||||
assert not hasattr(pipe, "test_component") or pipe.test_component is None
|
||||
|
||||
|
||||
class TestModularPipelineInitFallback:
    """Reloading a pipeline saved with a base-class `_blocks_class_name` must still resolve correctly.

    When `_blocks_class_name` is a base class (e.g. SequentialPipelineBlocks saved by
    from_blocks_dict), ModularPipeline.__init__ is expected to fall back to default_blocks_name.
    """

    def test_init_fallback_when_blocks_class_name_is_base_class(self, tmp_path):
        repo_id = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

        # Reference pipeline; its "text2image" workflow is a plain SequentialPipelineBlocks.
        reference_pipe = ModularPipeline.from_pretrained(repo_id)
        workflow_blocks = reference_pipe.blocks.get_workflow("text2image")
        assert workflow_blocks.__class__.__name__ == "SequentialPipelineBlocks"

        # Build a fresh pipeline from the workflow blocks only.
        workflow_pipe = workflow_blocks.init_pipeline(repo_id)

        # Round-trip through disk: the saved config carries _blocks_class_name="SequentialPipelineBlocks".
        target_dir = str(tmp_path / "pipeline")
        workflow_pipe.save_pretrained(target_dir)
        reloaded_pipe = ModularPipeline.from_pretrained(target_dir)

        # The reloaded pipeline must match the reference in pipeline class, blocks class and block count.
        assert reloaded_pipe.__class__.__name__ == reference_pipe.__class__.__name__

        reloaded_blocks = reloaded_pipe._blocks
        reference_blocks = reference_pipe._blocks
        assert reloaded_blocks.__class__.__name__ == reference_blocks.__class__.__name__
        assert len(reloaded_blocks.sub_blocks) == len(reference_blocks.sub_blocks)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import deque
|
||||
from typing import List
|
||||
|
||||
@@ -152,24 +153,25 @@ class TestModularCustomBlocks:
|
||||
output_prompt = output.values["output_prompt"]
|
||||
assert output_prompt.startswith("Modular diffusers + ")
|
||||
|
||||
def test_custom_block_saving_loading(self, tmp_path):
|
||||
def test_custom_block_saving_loading(self):
|
||||
custom_block = DummyCustomBlockSimple()
|
||||
|
||||
custom_block.save_pretrained(tmp_path)
|
||||
assert any("modular_config.json" in k for k in os.listdir(tmp_path))
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
custom_block.save_pretrained(tmpdir)
|
||||
assert any("modular_config.json" in k for k in os.listdir(tmpdir))
|
||||
|
||||
with open(os.path.join(tmp_path, "modular_config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
auto_map = config["auto_map"]
|
||||
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
|
||||
with open(os.path.join(tmpdir, "modular_config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
auto_map = config["auto_map"]
|
||||
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
|
||||
|
||||
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
|
||||
# This is why we have to save the Python script separately here.
|
||||
code_path = os.path.join(tmp_path, "test_modular_pipelines_custom_blocks.py")
|
||||
with open(code_path, "w") as f:
|
||||
f.write(CODE_STR)
|
||||
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
|
||||
# This is why we have to save the Python script separately here.
|
||||
code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py")
|
||||
with open(code_path, "w") as f:
|
||||
f.write(CODE_STR)
|
||||
|
||||
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmp_path, trust_remote_code=True)
|
||||
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True)
|
||||
|
||||
pipe = loaded_custom_block.init_pipeline()
|
||||
prompt = "Diffusers is nice"
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
This test suite exists for the maintainers currently. It's not run in our CI at the moment.
|
||||
|
||||
Once attention backends become more mature, we can consider including this in our CI.
|
||||
|
||||
To run this test suite:
|
||||
|
||||
```bash
|
||||
export RUN_ATTENTION_BACKEND_TESTS=yes
|
||||
|
||||
pytest tests/others/test_attention_backends.py
|
||||
```
|
||||
|
||||
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
|
||||
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
|
||||
|
||||
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
|
||||
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
|
||||
aiter 0.1.5.post4.dev20+ga25e55e79.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
|
||||
)
|
||||
from diffusers import FluxPipeline # noqa: E402
|
||||
from diffusers.utils import is_torch_version # noqa: E402
|
||||
|
||||
|
||||
# fmt: off
|
||||
FORWARD_CASES = [
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
|
||||
),
|
||||
(
|
||||
"native",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_native_cudnn",
|
||||
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
|
||||
),
|
||||
(
|
||||
"aiter",
|
||||
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
|
||||
)
|
||||
]
|
||||
|
||||
COMPILE_CASES = [
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"native",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"_native_cudnn",
|
||||
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"aiter",
|
||||
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
|
||||
True,
|
||||
)
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
INFER_KW = {
|
||||
"prompt": "dance doggo dance",
|
||||
"height": 256,
|
||||
"width": 256,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.5,
|
||||
"max_sequence_length": 128,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
|
||||
def _backend_is_probably_supported(pipe, name: str):
    """Try to switch `pipe.transformer` to attention backend `name`.

    Returns:
        `(pipe, True)` when `set_attention_backend` succeeds, or the bare value `False`
        when it raises any `Exception`. Callers tell the two apart with `isinstance(..., bool)`.
    """
    try:
        pipe.transformer.set_attention_backend(name)
    except Exception:
        # Best-effort probe: any failure just means "not supported in this environment".
        return False
    else:
        return pipe, True
|
||||
|
||||
|
||||
def _check_if_slices_match(output, expected_slice):
    """Assert that the first and last 8 flattened values of `output.images` match `expected_slice`.

    `output.images` is detached, moved to CPU and flattened before slicing.
    The comparison uses `torch.allclose` with `atol=1e-4`.
    """
    flat = output.images.detach().cpu().flatten()
    head, tail = flat[:8], flat[-8:]
    generated_slice = torch.cat([head, tail])
    assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def device():
    """Session-scoped fixture returning `cuda:0`; skips the requesting test when CUDA is unavailable."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    pytest.skip("CUDA is required for these tests.")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def pipe(device):
    """Session-scoped fixture: the FLUX.1-dev pipeline in bfloat16 on `device`, progress bar disabled."""
    flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    flux_pipe = flux_pipe.to(device)
    flux_pipe.set_progress_bar_config(disable=True)
    return flux_pipe
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
    """Run the pipeline with `backend_name` and compare an output slice with `expected_slice`."""
    probe = _backend_is_probably_supported(pipe, backend_name)
    # The probe returns a bare bool only when the backend could not be selected.
    if isinstance(probe, bool):
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    modified_pipe = probe[0]
    # Fixed seed so the generated slice is comparable with the recorded one.
    result = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
    _check_if_slices_match(result, expected_slice)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "backend_name,expected_slice,error_on_recompile",
    COMPILE_CASES,
    ids=[case[0] for case in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
    """Run the pipeline with a compiled transformer under `backend_name` and compare the output slice.

    The test is marked xfail when a "native" backend is combined with
    `error_on_recompile` on torch older than 2.9.0, or when the backend cannot
    be set on the transformer.
    """
    needs_newer_torch = "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0")
    if needs_newer_torch:
        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")

    probe = _backend_is_probably_supported(pipe, backend_name)
    if isinstance(probe, bool):
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    compiled_pipe = probe[0]
    compiled_pipe.transformer.compile(fullgraph=True)

    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        with torch._dynamo.config.patch(error_on_recompile=error_on_recompile):
            result = compiled_pipe(**INFER_KW, generator=torch.manual_seed(0))

    _check_if_slices_match(result, expected_slice)
|
||||
@@ -55,7 +55,7 @@ class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline):
|
||||
class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = Cosmos2_5_TransferWrapper
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"controls"})
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
@@ -176,19 +176,15 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controls_generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"controls": [torch.randn(3, 32, 32, generator=controls_generator) for _ in range(5)],
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 3,
|
||||
"num_frames_per_chunk": 16,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
@@ -216,56 +212,6 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
|
||||
self.assertTrue(torch.isfinite(generated_video).all())
|
||||
|
||||
def test_inference_autoregressive_multi_chunk(self):
    """Generate 5 frames with `num_frames_per_chunk=3` and one conditional frame on CPU.

    Checks that the first returned video has shape (5, 3, 32, 32) and contains only finite values.
    """
    device = "cpu"

    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(device)
    inputs.update(num_frames=5, num_frames_per_chunk=3, num_ar_conditional_frames=1)

    generated_video = pipe(**inputs).frames[0]
    self.assertEqual(generated_video.shape, (5, 3, 32, 32))
    self.assertTrue(torch.isfinite(generated_video).all())
|
||||
|
||||
def test_inference_autoregressive_multi_chunk_no_condition_frames(self):
    """Generate 5 frames with `num_frames_per_chunk=3` and zero conditional frames on CPU.

    Checks that the first returned video has shape (5, 3, 32, 32) and contains only finite values.
    """
    device = "cpu"

    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(device)
    inputs.update(num_frames=5, num_frames_per_chunk=3, num_ar_conditional_frames=0)

    generated_video = pipe(**inputs).frames[0]
    self.assertEqual(generated_video.shape, (5, 3, 32, 32))
    self.assertTrue(torch.isfinite(generated_video).all())
|
||||
|
||||
def test_num_frames_per_chunk_above_rope_raises(self):
    """Calling the pipeline with `num_frames_per_chunk=17` must raise a RoPE-related `ValueError`."""
    device = "cpu"

    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(device)
    inputs.update(num_frames_per_chunk=17)

    with self.assertRaisesRegex(ValueError, "too large for RoPE setting"):
        pipe(**inputs)
|
||||
|
||||
def test_inference_with_controls(self):
|
||||
"""Test inference with control inputs (ControlNet)."""
|
||||
device = "cpu"
|
||||
@@ -276,13 +222,13 @@ class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["controls"] = [torch.randn(3, 32, 32) for _ in range(5)] # list of 5 frames (C, H, W)
|
||||
# Add control video input - should be a video tensor
|
||||
inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width
|
||||
inputs["controls_conditioning_scale"] = 1.0
|
||||
inputs["num_frames"] = None
|
||||
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
|
||||
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
|
||||
self.assertTrue(torch.isfinite(generated_video).all())
|
||||
|
||||
def test_callback_inputs(self):
|
||||
|
||||
@@ -24,8 +24,7 @@ from diffusers import (
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors
|
||||
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
@@ -175,15 +174,6 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1):
    """Build an `LTX2LatentUpsamplerModel` from the given constructor arguments and return it."""
    upsampler_kwargs = {
        "in_channels": in_channels,
        "mid_channels": mid_channels,
        "num_blocks_per_stage": num_blocks_per_stage,
    }
    return LTX2LatentUpsamplerModel(**upsampler_kwargs)
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
@@ -297,60 +287,5 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_two_stages_inference_with_upsampler(self):
    """Chain a latent-output run, the latent upsample pipeline, and a second run on CPU.

    The first pipeline call returns video/audio latents, the video latents are
    passed through `LTX2LatentUpsamplePipeline`, and the upscaled latents plus
    the audio latents are fed back into the same pipeline with `output_type="pt"`.
    Shapes are checked at every stage and the final outputs are compared against
    stored reference slices.
    """

    def _ends(tensor):
        # First eight and last eight values of the flattened tensor.
        flat = tensor.flatten()
        return torch.cat([flat[:8], flat[-8:]])

    device = "cpu"

    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)
    vocoder_channels_source = components["vocoder"]

    # Stage 1: run the pipeline and keep the latents.
    inputs = self.get_dummy_inputs(device)
    inputs["output_type"] = "latent"
    stage_one = pipe(**inputs)
    video_latent, audio_latent = stage_one.frames, stage_one.audio

    self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
    self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
    self.assertEqual(audio_latent.shape[1], vocoder_channels_source.config.out_channels)

    # Upsample the video latents with a dummy upsampler sharing the pipeline's VAE.
    upsample_pipe = LTX2LatentUpsamplePipeline(
        vae=pipe.vae,
        latent_upsampler=self.get_dummy_upsample_component(in_channels=video_latent.shape[1]),
    )
    upsample_result = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)
    upscaled_video_latent = upsample_result[0]
    self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32))

    # Stage 2: reuse the same inputs dict, now seeded with the latents from above.
    inputs.update(latents=upscaled_video_latent, audio_latents=audio_latent, output_type="pt")
    stage_two = pipe(**inputs)
    video, audio = stage_two.frames, stage_two.audio

    self.assertEqual(video.shape, (1, 5, 3, 64, 64))
    self.assertEqual(audio.shape[0], 1)
    self.assertEqual(audio.shape[1], vocoder_channels_source.config.out_channels)

    # fmt: off
    expected_video_slice = torch.tensor(
        [
            0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079
        ]
    )
    expected_audio_slice = torch.tensor(
        [
            0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
        ]
    )
    # fmt: on

    generated_video_slice = _ends(video)
    generated_audio_slice = _ends(audio)

    assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
    assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_inference_batch_single_identical(self):
    """Delegate to the shared batch-vs-single check with batch size 2 and a 2e-2 max difference."""
    self._test_inference_batch_single_identical(
        expected_max_diff=2e-2,
        batch_size=2,
    )
|
||||
|
||||
@@ -72,6 +72,7 @@ OPTIONAL_TESTERS = [
|
||||
# Other testers
|
||||
("SingleFileTesterMixin", "single_file"),
|
||||
("IPAdapterTesterMixin", "ip_adapter"),
|
||||
("AttentionBackendTesterMixin", "attention_backends"),
|
||||
]
|
||||
|
||||
|
||||
@@ -530,6 +531,7 @@ def main():
|
||||
"faster_cache",
|
||||
"single_file",
|
||||
"ip_adapter",
|
||||
"attention_backends",
|
||||
"all",
|
||||
],
|
||||
help="Optional testers to include",
|
||||
|
||||
Reference in New Issue
Block a user