mirror of https://github.com/huggingface/diffusers.git
Hunyuanvideo15 (#12696)
* add

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
@@ -359,6 +359,8 @@
         title: HunyuanDiT2DModel
       - local: api/models/hunyuanimage_transformer_2d
         title: HunyuanImageTransformer2DModel
+      - local: api/models/hunyuan_video15_transformer_3d
+        title: HunyuanVideo15Transformer3DModel
       - local: api/models/hunyuan_video_transformer_3d
         title: HunyuanVideoTransformer3DModel
       - local: api/models/latte_transformer3d
@@ -433,6 +435,8 @@
         title: AutoencoderKLHunyuanImageRefiner
       - local: api/models/autoencoder_kl_hunyuan_video
         title: AutoencoderKLHunyuanVideo
+      - local: api/models/autoencoder_kl_hunyuan_video15
+        title: AutoencoderKLHunyuanVideo15
       - local: api/models/autoencoderkl_ltx_video
         title: AutoencoderKLLTXVideo
       - local: api/models/autoencoderkl_magvit
@@ -652,6 +656,8 @@
         title: Framepack
       - local: api/pipelines/hunyuan_video
         title: HunyuanVideo
+      - local: api/pipelines/hunyuan_video15
+        title: HunyuanVideo1.5
       - local: api/pipelines/i2vgenxl
         title: I2VGen-XL
       - local: api/pipelines/kandinsky5_video
36  docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md  Normal file
@@ -0,0 +1,36 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLHunyuanVideo15

The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo-1.5](https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5) by Tencent.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLHunyuanVideo15

vae = AutoencoderKLHunyuanVideo15.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="vae", torch_dtype=torch.float32
)

# make sure to enable tiling to avoid OOM
vae.enable_tiling()
```
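If you want to sanity-check the VAE on its own, a round trip through `encode`/`decode` looks like the sketch below. This is an illustrative addition rather than part of the original page: the dummy clip shape (batch, channels, frames, height, width) and the frame count are assumptions, and the exact temporal/spatial compression factors come from the model config.

```python
import torch

from diffusers import AutoencoderKLHunyuanVideo15

vae = AutoencoderKLHunyuanVideo15.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="vae", torch_dtype=torch.float32
)
vae.enable_tiling()

# Dummy clip in (batch, channels, frames, height, width) layout; the shape is illustrative only.
video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float32)

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # AutoencoderKLOutput -> DiagonalGaussianDistribution
    reconstruction = vae.decode(latents).sample       # DecoderOutput.sample

print(latents.shape, reconstruction.shape)
```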
## AutoencoderKLHunyuanVideo15

[[autodoc]] AutoencoderKLHunyuanVideo15
  - decode
  - encode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
30  docs/source/en/api/models/hunyuan_video15_transformer_3d.md  Normal file
@@ -0,0 +1,30 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# HunyuanVideo15Transformer3DModel

A Diffusion Transformer model for 3D video-like data used in [HunyuanVideo-1.5](https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5).

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import HunyuanVideo15Transformer3DModel

transformer = HunyuanVideo15Transformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="transformer", torch_dtype=torch.bfloat16
)
```
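A transformer loaded this way can be passed directly to the pipeline, which is useful if you want to quantize or otherwise customize it before assembling the rest of the components. The snippet below is a sketch of that standard Diffusers pattern and is not part of the original page.

```python
import torch

from diffusers import HunyuanVideo15Pipeline, HunyuanVideo15Transformer3DModel

transformer = HunyuanVideo15Transformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Load the remaining components from the same repository and swap in the transformer loaded above.
pipeline = HunyuanVideo15Pipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
```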
## HunyuanVideo15Transformer3DModel

[[autodoc]] HunyuanVideo15Transformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
120  docs/source/en/api/pipelines/hunyuan_video15.md  Normal file
@@ -0,0 +1,120 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# HunyuanVideo-1.5

HunyuanVideo-1.5 is a lightweight yet powerful video generation model that achieves state-of-the-art visual quality and motion coherence with only 8.3 billion parameters, enabling efficient inference on consumer-grade GPUs. This is built upon several key components, including meticulous data curation, an advanced DiT architecture with selective and sliding tile attention (SSTA), enhanced bilingual understanding through glyph-aware text encoding, progressive pre-training and post-training, and an efficient video super-resolution network. Leveraging these designs, the authors developed a unified framework capable of high-quality text-to-video and image-to-video generation across multiple durations and resolutions, and report that this compact model establishes a new state of the art among open-source video models.

You can find all the original HunyuanVideo-1.5 checkpoints under the [Tencent](https://huggingface.co/tencent) organization.

> [!TIP]
> Click on the HunyuanVideo-1.5 models in the right sidebar for more examples of video generation tasks.
>
> The examples below use a checkpoint from [hunyuanvideo-community](https://huggingface.co/hunyuanvideo-community) because the weights are stored in a layout compatible with Diffusers.

The example below demonstrates how to generate a video optimized for memory or inference speed.

<hfoptions id="usage">
<hfoption id="memory">

Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.

```py
import torch
from diffusers import HunyuanVideo15Pipeline
from diffusers.utils import export_to_video

pipeline = HunyuanVideo15Pipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v",
    torch_dtype=torch.bfloat16,
)

# model offloading and VAE tiling
pipeline.enable_model_cpu_offload()
pipeline.vae.enable_tiling()

prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "output.mp4", fps=15)
```

</hfoption>
</hfoptions>
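The commit also adds `HunyuanVideo15ImageToVideoPipeline` (documented below), but no usage snippet ships with it. The sketch below shows how image-to-video generation is likely to look based on other Diffusers image-to-video pipelines; the checkpoint name and the `image` argument are assumptions, so verify them against the `__call__` signature in the autodoc section.

```py
import torch
from diffusers import HunyuanVideo15ImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipeline = HunyuanVideo15ImageToVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_i2v",  # hypothetical i2v checkpoint name
    torch_dtype=torch.bfloat16,
)
pipeline.enable_model_cpu_offload()
pipeline.vae.enable_tiling()

image = load_image("input.png")  # any conditioning image
prompt = "The subject slowly turns toward the camera while soft light fills the room."

# `image` + `prompt` mirrors other Diffusers image-to-video pipelines; check the autodoc below for the exact arguments.
video = pipeline(image=image, prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "i2v_output.mp4", fps=15)
```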
## Notes

- HunyuanVideo-1.5 uses attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently.

  - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
  - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen`
  - **Other GPUs:** `sage_hub`

  Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.

  ```py
  pipe.transformer.set_attention_backend("flash_hub")  # or your preferred backend
  ```
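  If you only want to switch backends for a single call, the `attention_backend` context manager (its dispatch path is touched later in this commit) can be used instead of setting the backend globally. A sketch, assuming the helper is imported from `diffusers.models.attention_dispatch`:

  ```py
  from diffusers.models.attention_dispatch import attention_backend

  # Run this call with a different attention backend, then fall back to the default afterwards.
  with attention_backend("flash_varlen"):
      video = pipe(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
  ```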
- [`HunyuanVideo15Pipeline`] uses a guider and does not take a `guidance_scale` parameter at runtime.

  You can check the default guider configuration using `pipe.guider`:

  ```py
  >>> pipe.guider
  ClassifierFreeGuidance {
    "_class_name": "ClassifierFreeGuidance",
    "_diffusers_version": "0.36.0.dev0",
    "enabled": true,
    "guidance_rescale": 0.0,
    "guidance_scale": 6.0,
    "start": 0.0,
    "stop": 1.0,
    "use_original_formulation": false
  }

  State:
    step: None
    num_inference_steps: None
    timestep: None
    count_prepared: 0
    enabled: True
    num_conditions: 2
  ```

  To update the guider configuration, you can run `pipe.guider = pipe.guider.new(...)`:

  ```py
  pipe.guider = pipe.guider.new(guidance_scale=5.0)
  ```
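  You can also construct a replacement guider directly; `ClassifierFreeGuidance` is importable from `diffusers` (the conversion script later in this commit instantiates it the same way). A sketch that changes several fields at once, assuming the constructor accepts the config fields shown above:

  ```py
  from diffusers import ClassifierFreeGuidance

  # Swap in a freshly configured guider instead of calling `.new(...)`.
  pipe.guider = ClassifierFreeGuidance(guidance_scale=5.0, guidance_rescale=0.7)
  ```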
  Read more about guiders [here](../../modular_diffusers/guiders).

## HunyuanVideo15Pipeline

[[autodoc]] HunyuanVideo15Pipeline
  - all
  - __call__

## HunyuanVideo15ImageToVideoPipeline

[[autodoc]] HunyuanVideo15ImageToVideoPipeline
  - all
  - __call__

## HunyuanVideo15PipelineOutput

[[autodoc]] pipelines.hunyuan_video1_5.pipeline_output.HunyuanVideo15PipelineOutput
850  scripts/convert_hunyuan_video1_5_to_diffusers.py  Normal file
@@ -0,0 +1,850 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoTokenizer,
|
||||
SiglipImageProcessor,
|
||||
SiglipVisionModel,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
ClassifierFreeGuidance,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HunyuanVideo15ImageToVideoPipeline,
|
||||
HunyuanVideo15Pipeline,
|
||||
HunyuanVideo15Transformer3DModel,
|
||||
)
|
||||
|
||||
|
||||
# to convert only transformer
|
||||
"""
|
||||
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
|
||||
--original_state_dict_repo_id tencent/HunyuanVideo-1.5\
|
||||
--output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers/transformer\
|
||||
--transformer_type 480p_t2v
|
||||
"""
|
||||
|
||||
# to convert full pipeline
|
||||
"""
|
||||
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
|
||||
--original_state_dict_repo_id tencent/HunyuanVideo-1.5\
|
||||
--output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \
|
||||
--save_pipeline \
|
||||
--byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\
|
||||
--transformer_type 480p_t2v
|
||||
"""
|
||||
|
||||
|
||||
TRANSFORMER_CONFIGS = {
|
||||
"480p_t2v": {
|
||||
"target_size": 640,
|
||||
"task_type": "i2v",
|
||||
},
|
||||
"720p_t2v": {
|
||||
"target_size": 960,
|
||||
"task_type": "t2v",
|
||||
},
|
||||
"720p_i2v": {
|
||||
"target_size": 960,
|
||||
"task_type": "i2v",
|
||||
},
|
||||
"480p_t2v_distilled": {
|
||||
"target_size": 640,
|
||||
"task_type": "t2v",
|
||||
},
|
||||
"480p_i2v_distilled": {
|
||||
"target_size": 640,
|
||||
"task_type": "i2v",
|
||||
},
|
||||
"720p_i2v_distilled": {
|
||||
"target_size": 960,
|
||||
"task_type": "i2v",
|
||||
},
|
||||
}
|
||||
|
||||
SCHEDULER_CONFIGS = {
|
||||
"480p_t2v": {
|
||||
"shift": 5.0,
|
||||
},
|
||||
"480p_i2v": {
|
||||
"shift": 5.0,
|
||||
},
|
||||
"720p_t2v": {
|
||||
"shift": 9.0,
|
||||
},
|
||||
"720p_i2v": {
|
||||
"shift": 7.0,
|
||||
},
|
||||
"480p_t2v_distilled": {
|
||||
"shift": 5.0,
|
||||
},
|
||||
"480p_i2v_distilled": {
|
||||
"shift": 5.0,
|
||||
},
|
||||
"720p_i2v_distilled": {
|
||||
"shift": 7.0,
|
||||
},
|
||||
}
|
||||
|
||||
GUIDANCE_CONFIGS = {
|
||||
"480p_t2v": {
|
||||
"guidance_scale": 6.0,
|
||||
},
|
||||
"480p_i2v": {
|
||||
"guidance_scale": 6.0,
|
||||
},
|
||||
"720p_t2v": {
|
||||
"guidance_scale": 6.0,
|
||||
},
|
||||
"720p_i2v": {
|
||||
"guidance_scale": 6.0,
|
||||
},
|
||||
"480p_t2v_distilled": {
|
||||
"guidance_scale": 1.0,
|
||||
},
|
||||
"480p_i2v_distilled": {
|
||||
"guidance_scale": 1.0,
|
||||
},
|
||||
"720p_i2v_distilled": {
|
||||
"guidance_scale": 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
|
||||
"""
|
||||
Convert HunyuanVideo 1.5 original checkpoint to Diffusers format.
|
||||
"""
|
||||
converted_state_dict = {}
|
||||
|
||||
# 1. time_embed.timestep_embedder <- time_in
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
|
||||
"time_in.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.mlp.0.bias")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
|
||||
"time_in.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias")
|
||||
|
||||
# 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
|
||||
original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
|
||||
"txt_in.t_embedder.mlp.0.bias"
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
|
||||
original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
|
||||
"txt_in.t_embedder.mlp.2.bias"
|
||||
)
|
||||
|
||||
# 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder
|
||||
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
|
||||
"txt_in.c_embedder.linear_1.weight"
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
|
||||
"txt_in.c_embedder.linear_1.bias"
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
|
||||
"txt_in.c_embedder.linear_2.weight"
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
|
||||
"txt_in.c_embedder.linear_2.bias"
|
||||
)
|
||||
|
||||
# 4. context_embedder.proj_in <- txt_in.input_embedder
|
||||
converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight")
|
||||
converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")
|
||||
|
||||
# 5. context_embedder.token_refiner <- txt_in.individual_token_refiner
|
||||
num_refiner_blocks = 2
|
||||
for i in range(num_refiner_blocks):
|
||||
block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}."
|
||||
orig_prefix = f"txt_in.individual_token_refiner.blocks.{i}."
|
||||
|
||||
# norm1
|
||||
converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(f"{orig_prefix}norm1.weight")
|
||||
converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(f"{orig_prefix}norm1.bias")
|
||||
|
||||
# Split self_attn_qkv into to_q, to_k, to_v
|
||||
qkv_weight = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.weight")
|
||||
qkv_bias = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.bias")
|
||||
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
|
||||
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
|
||||
|
||||
# self_attn_proj -> attn.to_out.0
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}self_attn_proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}self_attn_proj.bias"
|
||||
)
|
||||
|
||||
# norm2
|
||||
converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(f"{orig_prefix}norm2.weight")
|
||||
converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(f"{orig_prefix}norm2.bias")
|
||||
|
||||
# mlp -> ff
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(f"{orig_prefix}mlp.fc2.bias")
|
||||
|
||||
# adaLN_modulation -> norm_out
|
||||
converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}adaLN_modulation.1.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}adaLN_modulation.1.bias"
|
||||
)
|
||||
|
||||
# 6. context_embedder_2 <- byt5_in
|
||||
converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight")
|
||||
converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias")
|
||||
converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight")
|
||||
converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias")
|
||||
converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight")
|
||||
converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias")
|
||||
converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight")
|
||||
converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias")
|
||||
|
||||
# 7. image_embedder <- vision_in
|
||||
converted_state_dict["image_embedder.norm_in.weight"] = original_state_dict.pop("vision_in.proj.0.weight")
|
||||
converted_state_dict["image_embedder.norm_in.bias"] = original_state_dict.pop("vision_in.proj.0.bias")
|
||||
converted_state_dict["image_embedder.linear_1.weight"] = original_state_dict.pop("vision_in.proj.1.weight")
|
||||
converted_state_dict["image_embedder.linear_1.bias"] = original_state_dict.pop("vision_in.proj.1.bias")
|
||||
converted_state_dict["image_embedder.linear_2.weight"] = original_state_dict.pop("vision_in.proj.3.weight")
|
||||
converted_state_dict["image_embedder.linear_2.bias"] = original_state_dict.pop("vision_in.proj.3.bias")
|
||||
converted_state_dict["image_embedder.norm_out.weight"] = original_state_dict.pop("vision_in.proj.4.weight")
|
||||
converted_state_dict["image_embedder.norm_out.bias"] = original_state_dict.pop("vision_in.proj.4.bias")
|
||||
|
||||
# 8. x_embedder <- img_in
|
||||
converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight")
|
||||
converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias")
|
||||
|
||||
# 9. cond_type_embed <- cond_type_embedding
|
||||
converted_state_dict["cond_type_embed.weight"] = original_state_dict.pop("cond_type_embedding.weight")
|
||||
|
||||
# 10. transformer_blocks <- double_blocks
|
||||
num_layers = 54
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
orig_prefix = f"double_blocks.{i}."
|
||||
|
||||
# norm1 (img_mod)
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mod.linear.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mod.linear.bias"
|
||||
)
|
||||
|
||||
# norm1_context (txt_mod)
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mod.linear.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mod.linear.bias"
|
||||
)
|
||||
|
||||
# img attention (to_q, to_k, to_v)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_q.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_q.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_k.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_k.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_v.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_v.bias"
|
||||
)
|
||||
|
||||
# img attention qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_k_norm.weight"
|
||||
)
|
||||
|
||||
# img attention output projection
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_proj.bias"
|
||||
)
|
||||
|
||||
# txt attention (add_q_proj, add_k_proj, add_v_proj)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_q.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_q.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_k.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_k.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_v.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_v.bias"
|
||||
)
|
||||
|
||||
# txt attention qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_k_norm.weight"
|
||||
)
|
||||
|
||||
# txt attention output projection
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_proj.bias"
|
||||
)
|
||||
|
||||
# norm2 and norm2_context (these don't have weights in the original, they're LayerNorm with elementwise_affine=False)
|
||||
# So we skip them
|
||||
|
||||
# img_mlp -> ff
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mlp.fc2.bias"
|
||||
)
|
||||
|
||||
# txt_mlp -> ff_context
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mlp.fc2.bias"
|
||||
)
|
||||
|
||||
# 11. norm_out and proj_out <- final_layer
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
||||
original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
|
||||
)
|
||||
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
||||
original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
|
||||
)
|
||||
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_hunyuan_video_15_vae_checkpoint_to_diffusers(
|
||||
original_state_dict, block_out_channels=[128, 256, 512, 1024, 1024], layers_per_block=2
|
||||
):
|
||||
converted = {}
|
||||
|
||||
# 1. Encoder
|
||||
# 1.1 conv_in
|
||||
converted["encoder.conv_in.conv.weight"] = original_state_dict.pop("encoder.conv_in.conv.weight")
|
||||
converted["encoder.conv_in.conv.bias"] = original_state_dict.pop("encoder.conv_in.conv.bias")
|
||||
|
||||
# 1.2 Down blocks
|
||||
for down_block_index in range(len(block_out_channels)): # 0 to 4
|
||||
# ResNet blocks
|
||||
for resnet_block_index in range(layers_per_block): # 0 to 1
|
||||
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
|
||||
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm1.gamma")
|
||||
)
|
||||
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
|
||||
original_state_dict.pop(
|
||||
f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.weight"
|
||||
)
|
||||
)
|
||||
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
|
||||
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.bias")
|
||||
)
|
||||
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
|
||||
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm2.gamma")
|
||||
)
|
||||
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
|
||||
original_state_dict.pop(
|
||||
f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.weight"
|
||||
)
|
||||
)
|
||||
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
|
||||
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.bias")
|
||||
)
|
||||
|
||||
# Downsample (if exists)
|
||||
if f"encoder.down.{down_block_index}.downsample.conv.conv.weight" in original_state_dict:
|
||||
converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.weight"] = (
|
||||
original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.weight")
|
||||
)
|
||||
converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.bias"] = (
|
||||
original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.bias")
|
||||
)
|
||||
|
||||
# 1.3 Mid block
|
||||
converted["encoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm1.gamma")
|
||||
converted["encoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
|
||||
"encoder.mid.block_1.conv1.conv.weight"
|
||||
)
|
||||
converted["encoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
|
||||
"encoder.mid.block_1.conv1.conv.bias"
|
||||
)
|
||||
converted["encoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm2.gamma")
|
||||
converted["encoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
|
||||
"encoder.mid.block_1.conv2.conv.weight"
|
||||
)
|
||||
converted["encoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
|
||||
"encoder.mid.block_1.conv2.conv.bias"
|
||||
)
|
||||
|
||||
converted["encoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm1.gamma")
|
||||
converted["encoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
|
||||
"encoder.mid.block_2.conv1.conv.weight"
|
||||
)
|
||||
converted["encoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
|
||||
"encoder.mid.block_2.conv1.conv.bias"
|
||||
)
|
||||
converted["encoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm2.gamma")
|
||||
converted["encoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
|
||||
"encoder.mid.block_2.conv2.conv.weight"
|
||||
)
|
||||
converted["encoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
|
||||
"encoder.mid.block_2.conv2.conv.bias"
|
||||
)
|
||||
|
||||
# Attention block
|
||||
converted["encoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("encoder.mid.attn_1.norm.gamma")
|
||||
converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
|
||||
converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
|
||||
converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
|
||||
converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
|
||||
converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
|
||||
converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
|
||||
converted["encoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
|
||||
"encoder.mid.attn_1.proj_out.weight"
|
||||
)
|
||||
converted["encoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
|
||||
"encoder.mid.attn_1.proj_out.bias"
|
||||
)
|
||||
|
||||
# 1.4 Encoder output
|
||||
converted["encoder.norm_out.gamma"] = original_state_dict.pop("encoder.norm_out.gamma")
|
||||
converted["encoder.conv_out.conv.weight"] = original_state_dict.pop("encoder.conv_out.conv.weight")
|
||||
converted["encoder.conv_out.conv.bias"] = original_state_dict.pop("encoder.conv_out.conv.bias")
|
||||
|
||||
# 2. Decoder
|
||||
# 2.1 conv_in
|
||||
converted["decoder.conv_in.conv.weight"] = original_state_dict.pop("decoder.conv_in.conv.weight")
|
||||
converted["decoder.conv_in.conv.bias"] = original_state_dict.pop("decoder.conv_in.conv.bias")
|
||||
|
||||
# 2.2 Mid block
|
||||
converted["decoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm1.gamma")
|
||||
converted["decoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
|
||||
"decoder.mid.block_1.conv1.conv.weight"
|
||||
)
|
||||
converted["decoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
|
||||
"decoder.mid.block_1.conv1.conv.bias"
|
||||
)
|
||||
converted["decoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm2.gamma")
|
||||
converted["decoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
|
||||
"decoder.mid.block_1.conv2.conv.weight"
|
||||
)
|
||||
converted["decoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
|
||||
"decoder.mid.block_1.conv2.conv.bias"
|
||||
)
|
||||
|
||||
converted["decoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm1.gamma")
|
||||
converted["decoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
|
||||
"decoder.mid.block_2.conv1.conv.weight"
|
||||
)
|
||||
converted["decoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
|
||||
"decoder.mid.block_2.conv1.conv.bias"
|
||||
)
|
||||
converted["decoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm2.gamma")
|
||||
converted["decoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
|
||||
"decoder.mid.block_2.conv2.conv.weight"
|
||||
)
|
||||
converted["decoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
|
||||
"decoder.mid.block_2.conv2.conv.bias"
|
||||
)
|
||||
|
||||
# Decoder attention block
|
||||
converted["decoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("decoder.mid.attn_1.norm.gamma")
|
||||
converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
|
||||
converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
|
||||
converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
|
||||
converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
|
||||
converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
|
||||
converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
|
||||
converted["decoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
|
||||
"decoder.mid.attn_1.proj_out.weight"
|
||||
)
|
||||
converted["decoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
|
||||
"decoder.mid.attn_1.proj_out.bias"
|
||||
)
|
||||
|
||||
# 2.3 Up blocks
|
||||
for up_block_index in range(len(block_out_channels)): # 0 to 5
|
||||
# ResNet blocks
|
||||
for resnet_block_index in range(layers_per_block + 1): # 0 to 2 (decoder has 3 resnets per level)
|
||||
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
|
||||
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm1.gamma")
|
||||
)
|
||||
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
|
||||
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.weight")
|
||||
)
|
||||
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
|
||||
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.bias")
|
||||
)
|
||||
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
|
||||
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm2.gamma")
|
||||
)
|
||||
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
|
||||
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.weight")
|
||||
)
|
||||
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
|
||||
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.bias")
|
||||
)
|
||||
|
||||
# Upsample (if exists)
|
||||
if f"decoder.up.{up_block_index}.upsample.conv.conv.weight" in original_state_dict:
|
||||
converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.weight"] = original_state_dict.pop(
|
||||
f"decoder.up.{up_block_index}.upsample.conv.conv.weight"
|
||||
)
|
||||
converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.bias"] = original_state_dict.pop(
|
||||
f"decoder.up.{up_block_index}.upsample.conv.conv.bias"
|
||||
)
|
||||
|
||||
# 2.4 Decoder output
|
||||
converted["decoder.norm_out.gamma"] = original_state_dict.pop("decoder.norm_out.gamma")
|
||||
converted["decoder.conv_out.conv.weight"] = original_state_dict.pop("decoder.conv_out.conv.weight")
|
||||
converted["decoder.conv_out.conv.bias"] = original_state_dict.pop("decoder.conv_out.conv.bias")
|
||||
|
||||
return converted
|
||||
|
||||
|
||||
def load_sharded_safetensors(dir: pathlib.Path):
|
||||
file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
|
||||
state_dict = {}
|
||||
for path in file_paths:
|
||||
state_dict.update(load_file(path))
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_original_transformer_state_dict(args):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
model_dir = snapshot_download(
|
||||
args.original_state_dict_repo_id,
|
||||
repo_type="model",
|
||||
allow_patterns="transformer/" + args.transformer_type + "/*",
|
||||
)
|
||||
elif args.original_state_dict_folder is not None:
|
||||
model_dir = pathlib.Path(args.original_state_dict_folder)
|
||||
else:
|
||||
raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`")
|
||||
model_dir = pathlib.Path(model_dir)
|
||||
model_dir = model_dir / "transformer" / args.transformer_type
|
||||
return load_sharded_safetensors(model_dir)
|
||||
|
||||
|
||||
def load_original_vae_state_dict(args):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(
|
||||
repo_id=args.original_state_dict_repo_id, filename="vae/diffusion_pytorch_model.safetensors"
|
||||
)
|
||||
elif args.original_state_dict_folder is not None:
|
||||
model_dir = pathlib.Path(args.original_state_dict_folder)
|
||||
ckpt_path = model_dir / "vae/diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`")
|
||||
|
||||
original_state_dict = load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_transformer(args):
|
||||
original_state_dict = load_original_transformer_state_dict(args)
|
||||
|
||||
config = TRANSFORMER_CONFIGS[args.transformer_type]
|
||||
with init_empty_weights():
|
||||
transformer = HunyuanVideo15Transformer3DModel(**config)
|
||||
state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict)
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(args):
|
||||
original_state_dict = load_original_vae_state_dict(args)
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLHunyuanVideo15()
|
||||
state_dict = convert_hunyuan_video_15_vae_checkpoint_to_diffusers(original_state_dict)
|
||||
vae.load_state_dict(state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def load_mllm():
|
||||
print(" loading from Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
text_encoder = AutoModel.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
|
||||
)
|
||||
if hasattr(text_encoder, "language_model"):
|
||||
text_encoder = text_encoder.language_model
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
|
||||
return text_encoder, tokenizer
|
||||
|
||||
|
||||
# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
|
||||
def add_special_token(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
add_color=True,
|
||||
add_font=True,
|
||||
multilingual=True,
|
||||
color_ann_path="assets/color_idx.json",
|
||||
font_ann_path="assets/multilingual_10-lang_idx.json",
|
||||
):
|
||||
"""
|
||||
Add special tokens for color and font to tokenizer and text encoder.
|
||||
|
||||
Args:
|
||||
tokenizer: Huggingface tokenizer.
|
||||
text_encoder: Huggingface T5 encoder.
|
||||
add_color (bool): Whether to add color tokens.
|
||||
add_font (bool): Whether to add font tokens.
|
||||
color_ann_path (str): Path to color annotation JSON.
|
||||
font_ann_path (str): Path to font annotation JSON.
|
||||
multilingual (bool): Whether to use multilingual font tokens.
|
||||
"""
|
||||
with open(font_ann_path, "r") as f:
|
||||
idx_font_dict = json.load(f)
|
||||
with open(color_ann_path, "r") as f:
|
||||
idx_color_dict = json.load(f)
|
||||
|
||||
if multilingual:
|
||||
font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
|
||||
else:
|
||||
font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))]
|
||||
color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]
|
||||
additional_special_tokens = []
|
||||
if add_color:
|
||||
additional_special_tokens += color_token
|
||||
if add_font:
|
||||
additional_special_tokens += font_token
|
||||
|
||||
tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
|
||||
# Set mean_resizing=False to avoid PyTorch LAPACK dependency
|
||||
text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
|
||||
|
||||
|
||||
def load_byt5(args):
|
||||
"""
|
||||
Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format.
|
||||
"""
|
||||
|
||||
# 1. Load base tokenizer and encoder
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
|
||||
|
||||
# Load as T5EncoderModel
|
||||
encoder = T5EncoderModel.from_pretrained("google/byt5-small")
|
||||
|
||||
byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt")
|
||||
color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json")
|
||||
font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json")
|
||||
|
||||
# 2. Add special tokens
|
||||
add_special_token(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=encoder,
|
||||
add_color=True,
|
||||
add_font=True,
|
||||
color_ann_path=color_ann_path,
|
||||
font_ann_path=font_ann_path,
|
||||
multilingual=True,
|
||||
)
|
||||
|
||||
# 3. Load Glyph-SDXL-v2 checkpoint
|
||||
print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}")
|
||||
checkpoint = torch.load(byt5_checkpoint_path, map_location="cpu")
|
||||
|
||||
# Handle different checkpoint formats
|
||||
if "state_dict" in checkpoint:
|
||||
state_dict = checkpoint["state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
# add 'encoder.' prefix to the keys
|
||||
# Remove 'module.text_tower.encoder.' prefix if present
|
||||
cleaned_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith("module.text_tower.encoder."):
|
||||
new_key = "encoder." + key[len("module.text_tower.encoder.") :]
|
||||
cleaned_state_dict[new_key] = value
|
||||
else:
|
||||
new_key = "encoder." + key
|
||||
cleaned_state_dict[new_key] = value
|
||||
|
||||
# 4. Load weights
|
||||
missing_keys, unexpected_keys = encoder.load_state_dict(cleaned_state_dict, strict=False)
|
||||
if unexpected_keys:
|
||||
raise ValueError(f"Unexpected keys: {unexpected_keys}")
|
||||
if "shared.weight" in missing_keys:
|
||||
print(" Missing shared.weight as expected")
|
||||
missing_keys.remove("shared.weight")
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing keys: {missing_keys}")
|
||||
|
||||
return encoder, tokenizer
|
||||
|
||||
|
||||
def load_siglip():
|
||||
image_encoder = SiglipVisionModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16
|
||||
)
|
||||
feature_extractor = SiglipImageProcessor.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor"
|
||||
)
|
||||
return image_encoder, feature_extractor
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict"
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
|
||||
parser.add_argument("--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()))
|
||||
parser.add_argument(
|
||||
"--byt5_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"path to the downloaded byt5 checkpoint & assets. "
|
||||
"Note: They use Glyph-SDXL-v2 as byt5 encoder. You can download from modelscope like: "
|
||||
"`modelscope download --model AI-ModelScope/Glyph-SDXL-v2 --local_dir ./ckpts/text_encoder/Glyph-SDXL-v2` "
|
||||
"or manually download following the instructions on "
|
||||
"https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/checkpoints-download.md. "
|
||||
"The path should point to the Glyph-SDXL-v2 folder which should contain an `assets` folder and a `checkpoints` folder, "
|
||||
"like: Glyph-SDXL-v2/assets/... and Glyph-SDXL-v2/checkpoints/byt5_model.pt"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
if args.save_pipeline and args.byt5_path is None:
|
||||
raise ValueError("Please provide --byt5_path when saving pipeline")
|
||||
|
||||
transformer = None
|
||||
|
||||
transformer = convert_transformer(args)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True)
|
||||
else:
|
||||
task_type = transformer.config.task_type
|
||||
|
||||
vae = convert_vae(args)
|
||||
|
||||
text_encoder, tokenizer = load_mllm()
|
||||
text_encoder_2, tokenizer_2 = load_byt5(args)
|
||||
|
||||
flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"]
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
|
||||
guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"]
|
||||
guider = ClassifierFreeGuidance(guidance_scale=guidance_scale)
|
||||
|
||||
if task_type == "i2v":
|
||||
image_encoder, feature_extractor = load_siglip()
|
||||
pipeline = HunyuanVideo15ImageToVideoPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
guider=guider,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
elif task_type == "t2v":
|
||||
pipeline = HunyuanVideo15Pipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
guider=guider,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Task type {task_type} is not supported")
|
||||
|
||||
pipeline.save_pretrained(args.output_path, safe_serialization=True)
|
||||
@@ -190,6 +190,7 @@ else:
         "AutoencoderKLHunyuanImage",
         "AutoencoderKLHunyuanImageRefiner",
         "AutoencoderKLHunyuanVideo",
+        "AutoencoderKLHunyuanVideo15",
         "AutoencoderKLLTXVideo",
         "AutoencoderKLMagvit",
         "AutoencoderKLMochi",
@@ -225,6 +226,7 @@ else:
         "HunyuanDiT2DModel",
         "HunyuanDiT2DMultiControlNetModel",
         "HunyuanImageTransformer2DModel",
+        "HunyuanVideo15Transformer3DModel",
         "HunyuanVideoFramepackTransformer3DModel",
         "HunyuanVideoTransformer3DModel",
         "I2VGenXLUNet",
@@ -481,6 +483,8 @@ else:
         "HunyuanImagePipeline",
         "HunyuanImageRefinerPipeline",
         "HunyuanSkyreelsImageToVideoPipeline",
+        "HunyuanVideo15ImageToVideoPipeline",
+        "HunyuanVideo15Pipeline",
         "HunyuanVideoFramepackPipeline",
         "HunyuanVideoImageToVideoPipeline",
         "HunyuanVideoPipeline",
@@ -909,6 +913,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AutoencoderKLHunyuanImage,
         AutoencoderKLHunyuanImageRefiner,
         AutoencoderKLHunyuanVideo,
+        AutoencoderKLHunyuanVideo15,
         AutoencoderKLLTXVideo,
         AutoencoderKLMagvit,
         AutoencoderKLMochi,
@@ -944,6 +949,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanDiT2DModel,
         HunyuanDiT2DMultiControlNetModel,
         HunyuanImageTransformer2DModel,
+        HunyuanVideo15Transformer3DModel,
         HunyuanVideoFramepackTransformer3DModel,
         HunyuanVideoTransformer3DModel,
         I2VGenXLUNet,
@@ -1170,6 +1176,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanImagePipeline,
         HunyuanImageRefinerPipeline,
         HunyuanSkyreelsImageToVideoPipeline,
+        HunyuanVideo15ImageToVideoPipeline,
+        HunyuanVideo15Pipeline,
         HunyuanVideoFramepackPipeline,
         HunyuanVideoImageToVideoPipeline,
         HunyuanVideoPipeline,
@@ -39,6 +39,7 @@ if is_torch_available():
     _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
     _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
     _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
+    _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
     _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
     _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
     _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
@@ -96,6 +97,7 @@ if is_torch_available():
     _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
     _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
     _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
     _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
     _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
     _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
@@ -147,6 +149,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        AutoencoderKLHunyuanImage,
        AutoencoderKLHunyuanImageRefiner,
        AutoencoderKLHunyuanVideo,
+       AutoencoderKLHunyuanVideo15,
        AutoencoderKLLTXVideo,
        AutoencoderKLMagvit,
        AutoencoderKLMochi,
@@ -199,6 +202,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        HiDreamImageTransformer2DModel,
        HunyuanDiT2DModel,
        HunyuanImageTransformer2DModel,
+       HunyuanVideo15Transformer3DModel,
        HunyuanVideoFramepackTransformer3DModel,
        HunyuanVideoTransformer3DModel,
        Kandinsky5Transformer3DModel,
@@ -282,6 +282,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke

    backend = AttentionBackendName(backend)
    _check_attention_backend_requirements(backend)
+   _maybe_download_kernel_for_backend(backend)

    old_backend = _AttentionBackendRegistry._active_backend
    _AttentionBackendRegistry._active_backend = backend
@@ -8,6 +8,7 @@ from .autoencoder_kl_flux2 import AutoencoderKLFlux2
 from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
 from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
 from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
+from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
 from .autoencoder_kl_magvit import AutoencoderKLMagvit
 from .autoencoder_kl_mochi import AutoencoderKLMochi
@@ -0,0 +1,967 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanVideo15CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
bias: bool = True,
|
||||
pad_mode: str = "replicate",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
self.time_causal_padding = (
|
||||
kernel_size[0] // 2,
|
||||
kernel_size[0] // 2,
|
||||
kernel_size[1] // 2,
|
||||
kernel_size[1] // 2,
|
||||
kernel_size[2] - 1,
|
||||
0,
|
||||
)
|
||||
|
||||
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
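For reference, a tiny standalone sketch (illustration only, not part of the change) of the replicate padding applied above. `F.pad` consumes the padding tuple last-dimension-first on a 5D tensor, so the three pairs map to width, height and time, and the whole `kernel - 1` temporal pad sits in front of the clip so no frame ever sees a future frame.

```python
import torch
import torch.nn.functional as F

# Toy 5D input: (batch, channels, frames, height, width)
x = torch.randn(1, 4, 5, 8, 8)
k = 3  # cubic kernel, matching kernel_size=3 above

# Same layout as time_causal_padding: (W_left, W_right, H_top, H_bottom, T_front, T_back)
pad = (k // 2, k // 2, k // 2, k // 2, k - 1, 0)
x_padded = F.pad(x, pad, mode="replicate")

# Space is padded symmetrically, time only toward the past.
assert x_padded.shape == (1, 4, 5 + (k - 1), 8 + 2 * (k // 2), 8 + 2 * (k // 2))
```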
class HunyuanVideo15RMS_norm(nn.Module):
|
||||
r"""
|
||||
A custom RMS normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The number of dimensions to normalize over.
|
||||
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
||||
Default is True.
|
||||
images (bool, optional): Whether the input represents image data. Default is True.
|
||||
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanVideo15AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = HunyuanVideo15RMS_norm(in_channels, images=False)
|
||||
|
||||
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
@staticmethod
|
||||
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
|
||||
"""Prepare a causal attention mask for 3D videos.
|
||||
|
||||
Args:
|
||||
n_frame (int): Number of frames (temporal length).
|
||||
n_hw (int): Product of height and width.
|
||||
dtype: Desired mask dtype.
|
||||
device: Device for the mask.
|
||||
batch_size (int, optional): If set, expands for batch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Causal attention mask.
|
||||
"""
|
||||
seq_len = n_frame * n_hw
|
||||
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
||||
for i in range(seq_len):
|
||||
i_frame = i // n_hw
|
||||
mask[i, : (i_frame + 1) * n_hw] = 0
|
||||
if batch_size is not None:
|
||||
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
return mask
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = x
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
batch_size, channels, frames, height, width = query.shape
|
||||
|
||||
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
|
||||
attention_mask = self.prepare_causal_attention_mask(
|
||||
frames, height * width, query.dtype, query.device, batch_size=batch_size
|
||||
)
|
||||
|
||||
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
|
||||
|
||||
# batch_size, 1, frames * height * width, channels
|
||||
|
||||
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x + identity
|
||||
|
||||
|
||||
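A small self-contained check (not part of the diff) that builds the same frame-causal mask without the Python loop: a token may attend to every spatial position in its own frame and in earlier frames, and to nothing in later frames.

```python
import torch

def block_causal_mask(n_frame: int, n_hw: int) -> torch.Tensor:
    # Frame index of every token in the flattened (frame * height * width) sequence.
    frame_idx = torch.arange(n_frame).repeat_interleave(n_hw)
    # A key is visible when its frame is not in the future of the query's frame.
    allowed = frame_idx[None, :] <= frame_idx[:, None]
    mask = torch.zeros(n_frame * n_hw, n_frame * n_hw)
    return mask.masked_fill(~allowed, float("-inf"))

# Matches the loop-based construction above for a toy size (n_frame=3, n_hw=2).
ref = torch.full((6, 6), float("-inf"))
for i in range(6):
    ref[i, : (i // 2 + 1) * 2] = 0
assert torch.equal(block_causal_mask(3, 2), ref)
```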
class HunyuanVideo15Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
|
||||
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels * factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_upsample = add_temporal_upsample
|
||||
self.repeats = factor * out_channels // in_channels
|
||||
|
||||
@staticmethod
|
||||
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
|
||||
"""
|
||||
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
|
||||
|
||||
Args:
|
||||
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
|
||||
r1: temporal upsampling factor
|
||||
r2: height upsampling factor
|
||||
r3: width upsampling factor
|
||||
"""
|
||||
b, packed_c, f, h, w = tensor.shape
|
||||
factor = r1 * r2 * r3
|
||||
c = packed_c // factor
|
||||
|
||||
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
|
||||
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
r1 = 2 if self.add_temporal_upsample else 1
|
||||
h = self.conv(x)
|
||||
if self.add_temporal_upsample:
|
||||
h_first = h[:, :, :1, :, :]
|
||||
h_first = self._dcae_upsample_rearrange(h_first, r1=1, r2=2, r3=2)
|
||||
h_first = h_first[:, : h_first.shape[1] // 2]
|
||||
h_next = h[:, :, 1:, :, :]
|
||||
h_next = self._dcae_upsample_rearrange(h_next, r1=r1, r2=2, r3=2)
|
||||
h = torch.cat([h_first, h_next], dim=2)
|
||||
|
||||
# shortcut computation
|
||||
x_first = x[:, :, :1, :, :]
|
||||
x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2)
|
||||
x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1)
|
||||
|
||||
x_next = x[:, :, 1:, :, :]
|
||||
x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2)
|
||||
x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
shortcut = torch.cat([x_first, x_next], dim=2)
|
||||
|
||||
else:
|
||||
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
|
||||
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
|
||||
return h + shortcut
|
||||
|
||||
|
||||
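As an aside, `_dcae_upsample_rearrange` is a 3D depth-to-space (pixel-shuffle style) move. A standalone equivalence check, assuming `einops` is available in the environment (it is not used by the diff itself):

```python
import torch
from einops import rearrange  # assumption: einops is installed, used only for this check

x = torch.randn(2, 2 * 2 * 2 * 5, 3, 4, 4)  # (b, r1*r2*r3*c, f, h, w) with c=5

# Reference: the view/permute/reshape sequence from the helper above.
b, packed_c, f, h, w = x.shape
c = packed_c // 8
ref = x.view(b, 2, 2, 2, c, f, h, w).permute(0, 4, 5, 1, 6, 2, 7, 3).reshape(b, c, f * 2, h * 2, w * 2)

# The same rearrangement written as an einops pattern.
out = rearrange(x, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=2, r2=2, r3=2)
assert torch.equal(out, ref)
```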
class HunyuanVideo15Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
|
||||
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_downsample = add_temporal_downsample
|
||||
self.group_size = factor * in_channels // out_channels
|
||||
|
||||
@staticmethod
|
||||
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
|
||||
"""
|
||||
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
|
||||
|
||||
This packs spatial/temporal dimensions into channels (opposite of upsample)
|
||||
"""
|
||||
b, c, packed_f, packed_h, packed_w = tensor.shape
|
||||
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
|
||||
|
||||
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
|
||||
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
r1 = 2 if self.add_temporal_downsample else 1
|
||||
h = self.conv(x)
|
||||
if self.add_temporal_downsample:
|
||||
h_first = h[:, :, :1, :, :]
|
||||
h_first = self._dcae_downsample_rearrange(h_first, r1=1, r2=2, r3=2)
|
||||
h_first = torch.cat([h_first, h_first], dim=1)
|
||||
h_next = h[:, :, 1:, :, :]
|
||||
h_next = self._dcae_downsample_rearrange(h_next, r1=r1, r2=2, r3=2)
|
||||
h = torch.cat([h_first, h_next], dim=2)
|
||||
|
||||
# shortcut computation
|
||||
x_first = x[:, :, :1, :, :]
|
||||
x_first = self._dcae_downsample_rearrange(x_first, r1=1, r2=2, r3=2)
|
||||
B, C, T, H, W = x_first.shape
|
||||
x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
|
||||
x_next = x[:, :, 1:, :, :]
|
||||
x_next = self._dcae_downsample_rearrange(x_next, r1=r1, r2=2, r3=2)
|
||||
B, C, T, H, W = x_next.shape
|
||||
x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
|
||||
shortcut = torch.cat([x_first, x_next], dim=2)
|
||||
else:
|
||||
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
|
||||
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
|
||||
B, C, T, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
|
||||
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanVideo15ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
non_linearity: str = "swish",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.norm1 = HunyuanVideo15RMS_norm(in_channels, images=False)
|
||||
self.conv1 = HunyuanVideo15CausalConv3d(in_channels, out_channels, kernel_size=3)
|
||||
|
||||
self.norm2 = HunyuanVideo15RMS_norm(out_channels, images=False)
|
||||
self.conv2 = HunyuanVideo15CausalConv3d(out_channels, out_channels, kernel_size=3)
|
||||
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class HunyuanVideo15MidBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_layers: int = 1,
|
||||
add_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.add_attention = add_attention
|
||||
|
||||
# There is always at least one resnet
|
||||
resnets = [
|
||||
HunyuanVideo15ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
if self.add_attention:
|
||||
attentions.append(HunyuanVideo15AttnBlock(in_channels))
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
HunyuanVideo15ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15DownBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
downsample_out_channels: Optional[int] = None,
|
||||
add_temporal_downsample: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
HunyuanVideo15ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if downsample_out_channels is not None:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideo15Downsample(
|
||||
out_channels,
|
||||
out_channels=downsample_out_channels,
|
||||
add_temporal_downsample=add_temporal_downsample,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15UpBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
upsample_out_channels: Optional[int] = None,
|
||||
add_temporal_upsample: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
input_channels = in_channels if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
HunyuanVideo15ResnetBlock(
|
||||
in_channels=input_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if upsample_out_channels is not None:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideo15Upsample(
|
||||
out_channels,
|
||||
out_channels=upsample_out_channels,
|
||||
add_temporal_upsample=add_temporal_upsample,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
||||
|
||||
else:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15Encoder3D(nn.Module):
|
||||
r"""
|
||||
3D VAE encoder for HunyuanVideo 1.5.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 64,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
temporal_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 16,
|
||||
downsample_match_channel: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.group_size = block_out_channels[-1] // self.out_channels
|
||||
|
||||
self.conv_in = HunyuanVideo15CausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
|
||||
self.mid_block = None
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
input_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
|
||||
output_channel = block_out_channels[i]
|
||||
if not add_spatial_downsample:
|
||||
down_block = HunyuanVideo15DownBlock3D(
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
downsample_out_channels=None,
|
||||
add_temporal_downsample=False,
|
||||
)
|
||||
input_channel = output_channel
|
||||
else:
|
||||
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
|
||||
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
|
||||
down_block = HunyuanVideo15DownBlock3D(
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
downsample_out_channels=downsample_out_channels,
|
||||
add_temporal_downsample=add_temporal_downsample,
|
||||
)
|
||||
input_channel = downsample_out_channels
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[-1])
|
||||
|
||||
self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
|
||||
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
else:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
batch_size, _, frame, height, width = hidden_states.shape
|
||||
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
hidden_states += short_cut
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
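A rough latent-shape sketch for this encoder (an illustration under stated assumptions: the default config above with spatial ratio 16, temporal ratio 4 and 32 latent channels, plus the usual causal-VAE convention that the first frame is kept and the remaining frames are compressed):

```python
# Pixel-space video: 121 frames at 480 x 832.
frames, height, width = 121, 480, 832

latent_frames = 1 + (frames - 1) // 4   # 31: first frame kept, remaining frames compressed 4x
latent_height = height // 16            # 30
latent_width = width // 16              # 52
print(latent_frames, latent_height, latent_width)

# The encoder head emits 2 * latent_channels = 64 channels: posterior mean and logvar.
```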
class HunyuanVideo15Decoder3D(nn.Module):
|
||||
r"""
|
||||
Causal decoder for 3D video-like data used in HunyuanVideo 1.5.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 32,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
upsample_match_channel: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.repeat = block_out_channels[0] // self.in_channels
|
||||
|
||||
self.conv_in = HunyuanVideo15CausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[0])
|
||||
|
||||
# up
|
||||
input_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
|
||||
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
|
||||
if add_spatial_upsample or add_temporal_upsample:
|
||||
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
|
||||
up_block = HunyuanVideo15UpBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
upsample_out_channels=upsample_out_channels,
|
||||
add_temporal_upsample=add_temporal_upsample,
|
||||
)
|
||||
input_channel = upsample_out_channels
|
||||
else:
|
||||
up_block = HunyuanVideo15UpBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
upsample_out_channels=None,
|
||||
add_temporal_upsample=False,
|
||||
)
|
||||
input_channel = output_channel
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# out
|
||||
self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
# post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanVideo-1.5.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
downsample_match_channel: bool = True,
|
||||
upsample_match_channel: bool = True,
|
||||
scaling_factor: float = 1.03682,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = HunyuanVideo15Encoder3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels * 2,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
downsample_match_channel=downsample_match_channel,
|
||||
)
|
||||
|
||||
self.decoder = HunyuanVideo15Decoder3D(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=list(reversed(block_out_channels)),
|
||||
layers_per_block=layers_per_block,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
upsample_match_channel=upsample_match_channel,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = spatial_compression_ratio
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
self.use_slicing = False
|
||||
|
||||
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
||||
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
||||
# intermediate tiles together, the memory requirement can be lowered.
|
||||
self.use_tiling = False
|
||||
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 256
|
||||
self.tile_sample_min_width = 256
|
||||
|
||||
# The minimal tile height and width in latent space
|
||||
self.tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio
|
||||
self.tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_latent_min_height: Optional[int] = None,
|
||||
tile_latent_min_width: Optional[int] = None,
|
||||
tile_overlap_factor: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_latent_min_height (`int`, *optional*):
|
||||
The minimum height required for a latent to be separated into tiles across the height dimension.
|
||||
tile_latent_min_width (`int`, *optional*):
|
||||
The minimum width required for a latent to be separated into tiles across the width dimension.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_latent_min_height = tile_latent_min_height or self.tile_latent_min_height
|
||||
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
x = self.encoder(x)
|
||||
return x
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
||||
return self.tiled_decode(z)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
return dec
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of videos.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
|
||||
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
|
||||
blend_height = int(self.tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
|
||||
blend_width = int(self.tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
|
||||
row_limit_height = self.tile_latent_min_height - blend_height # 8 - 2 = 6
|
||||
row_limit_width = self.tile_latent_min_width - blend_width # 8 - 2 = 6
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
tile = x[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i : i + self.tile_sample_min_height,
|
||||
j : j + self.tile_sample_min_width,
|
||||
]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
moments = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return moments
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
Returns:
`torch.Tensor`:
The decoded video.
|
||||
"""
|
||||
|
||||
_, _, _, height, width = z.shape
|
||||
|
||||
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
|
||||
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
|
||||
blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
|
||||
blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
|
||||
row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192
|
||||
row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
tile = z[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i : i + self.tile_latent_min_height,
|
||||
j : j + self.tile_latent_min_width,
|
||||
]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
dec = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return dec
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
return dec
|
||||
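To make the tile blending in `tiled_decode` concrete, a standalone toy of the linear crossfade that `blend_h` applies over the overlapping columns of two neighbouring decoded tiles (written without the loop purely for illustration):

```python
import torch

left = torch.zeros(1, 3, 1, 4, 8)    # right-hand edge region of the previously decoded tile
right = torch.ones(1, 3, 1, 4, 8)    # freshly decoded neighbouring tile
blend_extent = 4

# Same ramp as blend_h: fade `left` out and `right` in across the overlap.
ramp = torch.arange(blend_extent) / blend_extent   # 0.00, 0.25, 0.50, 0.75
right[..., :blend_extent] = left[..., -blend_extent:] * (1 - ramp) + right[..., :blend_extent] * ramp

print(right[0, 0, 0, 0, :blend_extent])  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
```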
@@ -29,6 +29,7 @@ if is_torch_available():
    from .transformer_flux2 import Flux2Transformer2DModel
    from .transformer_hidream_image import HiDreamImageTransformer2DModel
    from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
    from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
    from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
    from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
    from .transformer_kandinsky import Kandinsky5Transformer3DModel
836
src/diffusers/models/transformers/transformer_hunyuan_video15.py
Normal file
@@ -0,0 +1,836 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanVideo15AttnProcessor2_0:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"HunyuanVideo15AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# 2. QK normalization
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
# 4. Encoder condition QKV projection and normalization
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=1)
|
||||
key = torch.cat([key, encoder_key], dim=1)
|
||||
value = torch.cat([value, encoder_value], dim=1)
|
||||
|
||||
batch_size, seq_len, heads, dim = query.shape
|
||||
attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True)
|
||||
attention_mask = attention_mask.bool()
|
||||
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
||||
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||
attention_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 6. Output projection
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if getattr(attn, "to_add_out", None) is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
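For intuition, a tiny standalone rendering (made-up sizes) of how the 1D text mask is expanded into the joint self-attention mask above: latent tokens are prepended as always-valid positions, and the outer AND keeps a query/key pair only when both tokens are real.

```python
import torch
import torch.nn.functional as F

num_latent_tokens = 4
text_mask = torch.tensor([[1, 1, 0]])  # two real text tokens, one padding token
seq_len = num_latent_tokens + text_mask.shape[1]

# Left-pad with True so the latent tokens (which come first in the joint sequence) are never masked.
mask = F.pad(text_mask, (seq_len - text_mask.shape[1], 0), value=True).bool()
mask_1 = mask.view(1, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
joint_mask = mask_1 & mask_1.transpose(2, 3)

# Every row and column belonging to the padded text token is False; all other pairs are True.
print(joint_mask[0, 0].int())
```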
class HunyuanVideo15PatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Union[int, Tuple[int, int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15AdaNorm(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or 2 * in_features
|
||||
self.linear = nn.Linear(in_features, out_features)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
def forward(
|
||||
self, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
temb = self.linear(self.nonlinearity(temb))
|
||||
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
||||
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
||||
return gate_msa, gate_mlp
|
||||
|
||||
|
||||
class HunyuanVideo15TimeEmbedding(nn.Module):
|
||||
r"""
|
||||
Time embedding for HunyuanVideo 1.5.
|
||||
|
||||
Supports standard timestep embedding and optional reference timestep embedding for MeanFlow-based super-resolution
|
||||
models.
|
||||
|
||||
Args:
|
||||
embedding_dim (`int`):
|
||||
The dimension of the output embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class HunyuanVideo15IndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
|
||||
|
||||
self.norm_out = HunyuanVideo15AdaNorm(hidden_size, 2 * hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
gate_msa, gate_mlp = self.norm_out(temb)
|
||||
hidden_states = hidden_states + attn_output * gate_msa
|
||||
|
||||
ff_output = self.ff(self.norm2(hidden_states))
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15IndividualTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.refiner_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideo15IndividualTokenRefinerBlock(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
self_attn_mask = None
|
||||
if attention_mask is not None:
|
||||
batch_size = attention_mask.shape[0]
|
||||
seq_len = attention_mask.shape[1]
|
||||
attention_mask = attention_mask.to(hidden_states.device).bool()
|
||||
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
||||
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||
|
||||
for block in self.refiner_blocks:
|
||||
hidden_states = block(hidden_states, temb, self_attn_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15TokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=hidden_size, pooled_projection_dim=in_channels
|
||||
)
|
||||
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.token_refiner = HunyuanVideo15IndividualTokenRefiner(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_layers=num_layers,
|
||||
mlp_width_ratio=mlp_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is None:
|
||||
pooled_projections = hidden_states.mean(dim=1)
|
||||
else:
|
||||
original_dtype = hidden_states.dtype
|
||||
mask_float = attention_mask.float().unsqueeze(-1)
|
||||
pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
pooled_projections = pooled_projections.to(original_dtype)
|
||||
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
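The masked mean used for `pooled_projections` above, shown on a toy tensor (illustration only):

```python
import torch

hidden_states = torch.arange(12.0).reshape(1, 4, 3)   # (batch, seq, dim)
attention_mask = torch.tensor([[1, 1, 1, 0]])          # last token is padding

mask = attention_mask.float().unsqueeze(-1)
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
print(pooled)  # tensor([[3., 4., 5.]]) -- the mean over the three real tokens only
```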
class HunyuanVideo15RotaryPosEmbed(nn.Module):
|
||||
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.rope_dim = rope_dim
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
|
||||
|
||||
axes_grids = []
|
||||
for i in range(len(rope_sizes)):
|
||||
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
|
||||
# original implementation creates it on CPU and then moves it to device. This results in numerical
|
||||
# differences in layerwise debugging outputs, but visually it is the same.
|
||||
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
||||
axes_grids.append(grid)
|
||||
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
|
||||
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
||||
|
||||
freqs = []
|
||||
for i in range(3):
|
||||
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
|
||||
freqs.append(freq)
|
||||
|
||||
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
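The three RoPE axes split the 128-dim attention head into a 16/56/56 budget for time, height and width. A quick arithmetic check with hypothetical latent sizes (the 31 x 30 x 52 grid from the earlier shape sketch, patch sizes of 1):

```python
rope_axes_dim = (16, 56, 56)   # (time, height, width) frequency budget per head
attention_head_dim = 128
assert sum(rope_axes_dim) == attention_head_dim

rope_sizes = [31 // 1, 30 // 1, 52 // 1]   # frames, height, width after patchification
num_tokens = rope_sizes[0] * rope_sizes[1] * rope_sizes[2]
print(num_tokens)  # 48360 positions share the concatenated cos/sin tables
```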
class HunyuanVideo15ByT5TextProjection(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_size: int, out_features: int):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(in_features)
|
||||
self.linear_1 = nn.Linear(in_features, hidden_size)
|
||||
self.linear_2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear_3 = nn.Linear(hidden_size, out_features)
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm(encoder_hidden_states)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_3(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15ImageProjection(nn.Module):
|
||||
def __init__(self, in_channels: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.norm_in = nn.LayerNorm(in_channels)
|
||||
self.linear_1 = nn.Linear(in_channels, in_channels)
|
||||
self.act_fn = nn.GELU()
|
||||
self.linear_2 = nn.Linear(in_channels, hidden_size)
|
||||
self.norm_out = nn.LayerNorm(hidden_size)
|
||||
|
||||
def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm_in(image_embeds)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=hidden_size,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=HunyuanVideo15AttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
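The block above follows the familiar dual-stream AdaLN pattern: the time embedding produces shift/scale/gate values, and the attention and feed-forward branches are gated before being added back to the residual. A minimal standalone sketch of that gating (under the usual AdaLN-Zero convention of zero-initialising the modulation layer; not the diffusers implementation itself):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size = 8
x = torch.randn(2, 5, hidden_size)   # token stream: (batch, seq, dim)
temb = torch.randn(2, hidden_size)   # conditioning embedding

# The modulation layer emits shift / scale / gate; zero init makes the block start as an identity map.
ada = nn.Linear(hidden_size, 3 * hidden_size)
nn.init.zeros_(ada.weight)
nn.init.zeros_(ada.bias)
shift, scale, gate = ada(F.silu(temb)).chunk(3, dim=1)

norm_x = F.layer_norm(x, (hidden_size,))                 # no learned affine, as in the blocks above
modulated = norm_x * (1 + scale[:, None]) + shift[:, None]
branch_out = modulated                                   # stand-in for the attention / MLP branch
out = x + gate[:, None] * branch_out                     # gated residual update
assert torch.equal(out, x)                               # identity at initialisation
```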
class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo1.5](https://huggingface.co/tencent/HunyuanVideo1.5).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `65`):
The number of channels in the input.
out_channels (`int`, defaults to `32`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `16`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
num_layers (`int`, defaults to `54`):
The number of layers of dual-stream blocks to use.
num_refiner_layers (`int`, defaults to `2`):
The number of layers of refiner blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
The ratio of the hidden layer size to the input size in the feedforward network.
patch_size (`int`, defaults to `1`):
The size of the spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of the temporal patches to use in the patch embedding layer.
qk_norm (`str`, defaults to `rms_norm`):
The normalization to use for the query and key projections in the attention layers.
text_embed_dim (`int`, defaults to `3584`):
Input dimension of text embeddings from the primary text encoder.
text_embed_2_dim (`int`, defaults to `1472`):
Input dimension of text embeddings from the secondary (ByT5) text encoder.
image_embed_dim (`int`, defaults to `1152`):
Input dimension of the image embeddings used for image conditioning.
rope_theta (`float`, defaults to `256.0`):
The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions of the axes to use in the RoPE layer.
target_size (`int`, defaults to `640`):
The target resolution in pixel space used to select size-dependent configuration.
task_type (`str`, defaults to `"i2v"`):
The generation task the checkpoint is configured for, e.g. `"i2v"` or `"t2v"`.
|
||||
"""
|
||||
|
||||
    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
    _no_split_modules = [
        "HunyuanVideo15TransformerBlock",
        "HunyuanVideo15PatchEmbed",
        "HunyuanVideo15TokenRefiner",
    ]
    _repeated_blocks = [
        "HunyuanVideo15TransformerBlock",
        "HunyuanVideo15PatchEmbed",
        "HunyuanVideo15TokenRefiner",
    ]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 65,
        out_channels: int = 32,
        num_attention_heads: int = 16,
        attention_head_dim: int = 128,
        num_layers: int = 54,
        num_refiner_layers: int = 2,
        mlp_ratio: float = 4.0,
        patch_size: int = 1,
        patch_size_t: int = 1,
        qk_norm: str = "rms_norm",
        text_embed_dim: int = 3584,
        text_embed_2_dim: int = 1472,
        image_embed_dim: int = 1152,
        rope_theta: float = 256.0,
        rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
        # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205
        target_size: int = 640,  # did not name sample_size since it is in pixel spaces
        task_type: str = "i2v",
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Latent and condition embedders
        self.x_embedder = HunyuanVideo15PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
        self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim)

        self.context_embedder = HunyuanVideo15TokenRefiner(
            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
        )
        self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim)

        self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim)

        self.cond_type_embed = nn.Embedding(3, inner_dim)

        # 2. RoPE
        self.rope = HunyuanVideo15RotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks

        self.transformer_blocks = nn.ModuleList(
            [
                HunyuanVideo15TransformerBlock(
                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                )
                for _ in range(num_layers)
            ]
        )

        # 5. Output projection
        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states_2: Optional[torch.Tensor] = None,
        encoder_attention_mask_2: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # 1. RoPE
        image_rotary_emb = self.rope(hidden_states)

        # 2. Conditional embeddings
        temb = self.time_embed(timestep)

        hidden_states = self.x_embedder(hidden_states)

        # qwen text embedding
        encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)

        encoder_hidden_states_cond_emb = self.cond_type_embed(
            torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long)
        )
        encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb

        # byt5 text embedding
        encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)

        encoder_hidden_states_2_cond_emb = self.cond_type_embed(
            torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long)
        )
        encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb

        # image embed
        encoder_hidden_states_3 = self.image_embedder(image_embeds)
        is_t2v = torch.all(image_embeds == 0)
        if is_t2v:
            encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0
            encoder_attention_mask_3 = torch.zeros(
                (batch_size, encoder_hidden_states_3.shape[1]),
                dtype=encoder_attention_mask.dtype,
                device=encoder_attention_mask.device,
            )
        else:
            encoder_attention_mask_3 = torch.ones(
                (batch_size, encoder_hidden_states_3.shape[1]),
                dtype=encoder_attention_mask.dtype,
                device=encoder_attention_mask.device,
            )
        encoder_hidden_states_3_cond_emb = self.cond_type_embed(
            2
            * torch.ones_like(
                encoder_hidden_states_3[:, :, 0],
                dtype=torch.long,
            )
        )
        encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb

        # reorder and combine text tokens: combine valid tokens first, then padding
        encoder_attention_mask = encoder_attention_mask.bool()
        encoder_attention_mask_2 = encoder_attention_mask_2.bool()
        encoder_attention_mask_3 = encoder_attention_mask_3.bool()
        new_encoder_hidden_states = []
        new_encoder_attention_mask = []

        for text, text_mask, text_2, text_mask_2, image, image_mask in zip(
            encoder_hidden_states,
            encoder_attention_mask,
            encoder_hidden_states_2,
            encoder_attention_mask_2,
            encoder_hidden_states_3,
            encoder_attention_mask_3,
        ):
            # Concatenate: [valid_image, valid_byt5, valid_mllm, invalid_image, invalid_byt5, invalid_mllm]
            new_encoder_hidden_states.append(
                torch.cat(
                    [
                        image[image_mask],  # valid image
                        text_2[text_mask_2],  # valid byt5
                        text[text_mask],  # valid mllm
                        image[~image_mask],  # invalid image
                        torch.zeros_like(text_2[~text_mask_2]),  # invalid byt5 (zeroed)
                        torch.zeros_like(text[~text_mask]),  # invalid mllm (zeroed)
                    ],
                    dim=0,
                )
            )

            # Apply same reordering to attention masks
            new_encoder_attention_mask.append(
                torch.cat(
                    [
                        image_mask[image_mask],
                        text_mask_2[text_mask_2],
                        text_mask[text_mask],
                        image_mask[~image_mask],
                        text_mask_2[~text_mask_2],
                        text_mask[~text_mask],
                    ],
                    dim=0,
                )
            )

        encoder_hidden_states = torch.stack(new_encoder_hidden_states)
        encoder_attention_mask = torch.stack(new_encoder_attention_mask)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    encoder_attention_mask,
                    image_rotary_emb,
                )

        else:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    encoder_attention_mask,
                    image_rotary_emb,
                )

        # 5. Output projection
        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p_h, p_w
        )
        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (hidden_states,)

        return Transformer2DModelOutput(sample=hidden_states)

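# Shape sketch for `HunyuanVideo15Transformer3DModel.forward` (illustrative; derived from the default
# config above and the pipelines later in this PR, not part of the original source):
#     hidden_states:            (B, 65, F_lat, H_lat, W_lat)  # 32 noisy + 32 cond latents + 1 mask channel
#     timestep:                 (B,)
#     encoder_hidden_states:    (B, 1000, 3584)  # Qwen2.5-VL hidden states after template cropping
#     encoder_attention_mask:   (B, 1000)
#     encoder_hidden_states_2:  (B, 256, 1472)   # ByT5 glyph embeddings
#     encoder_attention_mask_2: (B, 256)
#     image_embeds:             (B, 729, 1152)   # all zeros for text-to-video
# With the default patch size of 1, the returned sample has shape (B, 32, F_lat, H_lat, W_lat).
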
@@ -243,6 +243,7 @@ else:
        "HunyuanVideoImageToVideoPipeline",
        "HunyuanVideoFramepackPipeline",
    ]
    _import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline", "HunyuanVideo15ImageToVideoPipeline"]
    _import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
    _import_structure["kandinsky"] = [
        "KandinskyCombinedPipeline",
@@ -665,6 +666,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            HunyuanVideoImageToVideoPipeline,
            HunyuanVideoPipeline,
        )
        from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline
        from .hunyuandit import HunyuanDiTPipeline
        from .i2vgen_xl import I2VGenXLPipeline
        from .kandinsky import (

50
src/diffusers/pipelines/hunyuan_video1_5/__init__.py
Normal file
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"]
    _import_structure["pipeline_hunyuan_video1_5_image2video"] = ["HunyuanVideo15ImageToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline
        from .pipeline_hunyuan_video1_5_image2video import HunyuanVideo15ImageToVideoPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

103
src/diffusers/pipelines/hunyuan_video1_5/image_processor.py
Normal file
@@ -0,0 +1,103 @@
# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from ...configuration_utils import register_to_config
from ...video_processor import VideoProcessor


# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L20
def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list

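# Illustrative note (not from the original source): with the defaults above, the staircase walk produces
# crop buckets such as (512, 128), (256, 256) and (128, 512), i.e. side lengths that are multiples of
# `patch_size`, aspect ratios between 4:1 and 1:4, and a patch count of at most (base_size / patch_size) ** 2.

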
# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
    """
    Get the closest ratio in the buckets.

    Args:
        height (float): video height
        width (float): video width
        ratios (list): video aspect ratios
        buckets (list): buckets generated by `generate_crop_size_list`

    Returns:
        the closest size in the buckets and the corresponding ratio
    """
    aspect_ratio = float(height) / float(width)
    diff_ratios = ratios - aspect_ratio

    if aspect_ratio >= 1:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
    else:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x >= 0]

    closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
    closest_size = buckets[closest_ratio_id]
    closest_ratio = ratios[closest_ratio_id]

    return closest_size, closest_ratio


class HunyuanVideo15ImageProcessor(VideoProcessor):
    r"""
    Image/video processor to preprocess/postprocess the reference image/generated video for the HunyuanVideo1.5 model.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `16`):
            VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
            this factor.
        vae_latent_channels (`int`, *optional*, defaults to `32`):
            VAE latent channels.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 16,
        vae_latent_channels: int = 32,
        do_convert_rgb: bool = True,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            do_convert_rgb=do_convert_rgb,
        )

    def calculate_default_height_width(self, height: int, width: int, target_size: int):
        crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.config.vae_scale_factor)
        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
        height, width = get_closest_ratio(height, width, aspect_ratios, crop_size_list)[0]

        return height, width

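
# A minimal usage sketch (illustrative only; the exact bucket values depend on `target_size`):
#
#     processor = HunyuanVideo15ImageProcessor(vae_scale_factor=16)
#     height, width = processor.calculate_default_height_width(720, 1280, target_size=640)
#     # `height` and `width` snap to the crop bucket whose aspect ratio is closest to 720 / 1280,
#     # with both sides kept as multiples of the VAE spatial scale factor (16).
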
@@ -0,0 +1,837 @@
# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel

from ...guiders import ClassifierFreeGuidance
from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .image_processor import HunyuanVideo15ImageProcessor
from .pipeline_output import HunyuanVideo15PipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import HunyuanVideo15Pipeline
        >>> from diffusers.utils import export_to_video

        >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_t2v"
        >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        >>> pipe.vae.enable_tiling()
        >>> pipe.to("cuda")

        >>> output = pipe(
        ...     prompt="A cat walks on the grass, realistic",
        ...     num_inference_steps=50,
        ... ).frames[0]
        >>> export_to_video(output, "output.mp4", fps=15)
        ```
"""


def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]:
    """
    Apply text to template.

    Args:
        prompt (List[str]): Input text.
        system_message (str): System message.

    Returns:
        List[Dict[str, Any]]: List of chat conversation.
    """

    template = [
        [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt
    ]

    return template


def extract_glyph_texts(prompt: str) -> List[str]:
    """
    Extract glyph texts from prompt using regex pattern.

    Args:
        prompt: Input prompt string

    Returns:
        List of extracted glyph texts
    """
    pattern = r"\"(.*?)\"|“(.*?)”"
    matches = re.findall(pattern, prompt)
    result = [match[0] or match[1] for match in matches]
    result = list(dict.fromkeys(result)) if len(result) > 1 else result

    if result:
        formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". "
    else:
        formatted_result = None

    return formatted_result

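
# Illustrative examples (not part of the original source):
#     extract_glyph_texts('A neon sign that reads "OPEN" at night')  ->  'Text "OPEN". '
#     extract_glyph_texts("A quiet beach at sunset")                 ->  None
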

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class HunyuanVideo15Pipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using HunyuanVideo1.5.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        transformer ([`HunyuanVideo15Transformer3DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
        vae ([`AutoencoderKLHunyuanVideo15`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`Qwen2_5_VLTextModel`]):
            The text model of [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
        tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
        text_encoder_2 ([`T5EncoderModel`]):
            [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel),
            the ByT5 variant used to encode glyph texts.
        tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer].
        guider ([`ClassifierFreeGuidance`]):
            [`ClassifierFreeGuidance`] for classifier-free guidance.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"

    def __init__(
        self,
        text_encoder: Qwen2_5_VLTextModel,
        tokenizer: Qwen2Tokenizer,
        transformer: HunyuanVideo15Transformer3DModel,
        vae: AutoencoderKLHunyuanVideo15,
        scheduler: FlowMatchEulerDiscreteScheduler,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: ByT5Tokenizer,
        guider: ClassifierFreeGuidance,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            guider=guider,
        )

        self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
        self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640
        self.vision_states_dim = (
            self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152
        )
        self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32
        # fmt: off
        self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \
            1. The main content and theme of the video. \
            2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
            3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
            4. background environment, light, style and atmosphere. \
            5. camera angles, movements, and transitions used in the video."
        # fmt: on
        self.prompt_template_encode_start_idx = 108
        self.tokenizer_max_length = 1000
        self.tokenizer_2_max_length = 256
        self.vision_num_semantic_tokens = 729
        self.default_aspect_ratio = (16, 9)  # (width: height)

    @staticmethod
    def _get_mllm_prompt_embeds(
        text_encoder: Qwen2_5_VLTextModel,
        tokenizer: Qwen2Tokenizer,
        prompt: Union[str, List[str]],
        device: torch.device,
        tokenizer_max_length: int = 1000,
        num_hidden_layers_to_skip: int = 2,
        # fmt: off
        system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \
            1. The main content and theme of the video. \
            2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
            3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
            4. background environment, light, style and atmosphere. \
            5. camera angles, movements, and transitions used in the video.",
        # fmt: on
        crop_start: int = 108,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        prompt = [prompt] if isinstance(prompt, str) else prompt

        prompt = format_text_input(prompt, system_message)

        text_inputs = tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            padding="max_length",
            max_length=tokenizer_max_length + crop_start,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device=device)
        prompt_attention_mask = text_inputs.attention_mask.to(device=device)

        prompt_embeds = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_attention_mask,
            output_hidden_states=True,
        ).hidden_states[-(num_hidden_layers_to_skip + 1)]

        if crop_start is not None and crop_start > 0:
            prompt_embeds = prompt_embeds[:, crop_start:]
            prompt_attention_mask = prompt_attention_mask[:, crop_start:]

        return prompt_embeds, prompt_attention_mask

    @staticmethod
    def _get_byt5_prompt_embeds(
        tokenizer: ByT5Tokenizer,
        text_encoder: T5EncoderModel,
        prompt: Union[str, List[str]],
        device: torch.device,
        tokenizer_max_length: int = 256,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt

        glyph_texts = [extract_glyph_texts(p) for p in prompt]

        prompt_embeds_list = []
        prompt_embeds_mask_list = []

        for glyph_text in glyph_texts:
            if glyph_text is None:
                glyph_text_embeds = torch.zeros(
                    (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype
                )
                glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64)
            else:
                txt_tokens = tokenizer(
                    glyph_text,
                    padding="max_length",
                    max_length=tokenizer_max_length,
                    truncation=True,
                    add_special_tokens=True,
                    return_tensors="pt",
                ).to(device)

                glyph_text_embeds = text_encoder(
                    input_ids=txt_tokens.input_ids,
                    attention_mask=txt_tokens.attention_mask.float(),
                )[0]
                glyph_text_embeds = glyph_text_embeds.to(device=device)
                glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device)

            prompt_embeds_list.append(glyph_text_embeds)
            prompt_embeds_mask_list.append(glyph_text_embeds_mask)

        prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
        prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0)

        return prompt_embeds, prompt_embeds_mask

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        batch_size: int = 1,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        prompt_embeds_2: Optional[torch.Tensor] = None,
        prompt_embeds_mask_2: Optional[torch.Tensor] = None,
    ):
        r"""

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            batch_size (`int`):
                batch size of prompts, defaults to 1
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
                argument.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
            prompt_embeds_2 (`torch.Tensor`, *optional*):
                Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
                argument using self.tokenizer_2 and self.text_encoder_2.
            prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
                Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
                argument using self.tokenizer_2 and self.text_encoder_2.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        if prompt is None:
            prompt = [""] * batch_size

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds(
                tokenizer=self.tokenizer,
                text_encoder=self.text_encoder,
                prompt=prompt,
                device=device,
                tokenizer_max_length=self.tokenizer_max_length,
                system_message=self.system_message,
                crop_start=self.prompt_template_encode_start_idx,
            )

        if prompt_embeds_2 is None:
            prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds(
                tokenizer=self.tokenizer_2,
                text_encoder=self.text_encoder_2,
                prompt=prompt,
                device=device,
                tokenizer_max_length=self.tokenizer_2_max_length,
            )

        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len)

        _, seq_len_2, _ = prompt_embeds_2.shape
        prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1)
        prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2)

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device)
        prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device)
        prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device)

        return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_embeds_mask=None,
        negative_prompt_embeds_mask=None,
        prompt_embeds_2=None,
        prompt_embeds_mask_2=None,
        negative_prompt_embeds_2=None,
        negative_prompt_embeds_mask_2=None,
    ):
        if height is None and width is not None:
            raise ValueError("If `width` is provided, `height` also has to be provided.")
        elif width is None and height is not None:
            raise ValueError("If `height` is provided, `width` also has to be provided.")

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and prompt_embeds_mask is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        if prompt is None and prompt_embeds_2 is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
            )

        if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
            raise ValueError(
                "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
            )
        if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
            raise ValueError(
                "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
            )

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 32,
        height: int = 720,
        width: int = 1280,
        num_frames: int = 129,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        shape = (
            batch_size,
            num_channels_latents,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    def prepare_cond_latents_and_mask(self, latents, dtype: Optional[torch.dtype], device: Optional[torch.device]):
        """
        Prepare conditional latents and mask for t2v generation.

        Args:
            latents: Main latents tensor (B, C, F, H, W)

        Returns:
            tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v
        """
        batch, channels, frames, height, width = latents.shape

        cond_latents_concat = torch.zeros(batch, channels, frames, height, width, dtype=dtype, device=device)

        mask_concat = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device)

        return cond_latents_concat, mask_concat

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: int = 121,
        num_inference_steps: int = 50,
        sigmas: List[float] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
        prompt_embeds_2: Optional[torch.Tensor] = None,
        prompt_embeds_mask_2: Optional[torch.Tensor] = None,
        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
        negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
            height (`int`, *optional*):
                The height in pixels of the generated video.
            width (`int`, *optional*):
                The width in pixels of the generated video.
            num_frames (`int`, defaults to `121`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                Pre-generated mask for prompt embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
                Pre-generated mask for negative prompt embeddings.
            prompt_embeds_2 (`torch.Tensor`, *optional*):
                Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs.
            prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
                Pre-generated mask for prompt embeddings from the second text encoder.
            negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings from the second text encoder.
            negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
                Pre-generated mask for negative prompt embeddings from the second text encoder.
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated video. Choose between "np", "pt", or "latent".
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

        Examples:

        Returns:
            [`~HunyuanVideo15PipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated videos.
        """

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            prompt_embeds_2=prompt_embeds_2,
            prompt_embeds_mask_2=prompt_embeds_mask_2,
            negative_prompt_embeds_2=negative_prompt_embeds_2,
            negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
        )

        if height is None and width is None:
            height, width = self.video_processor.calculate_default_height_width(
                self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size
            )

        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
            prompt=prompt,
            device=device,
            dtype=self.transformer.dtype,
            batch_size=batch_size,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            prompt_embeds_2=prompt_embeds_2,
            prompt_embeds_mask_2=prompt_embeds_mask_2,
        )

        if self.guider._enabled and self.guider.num_conditions > 1:
            (
                negative_prompt_embeds,
                negative_prompt_embeds_mask,
                negative_prompt_embeds_2,
                negative_prompt_embeds_mask_2,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                device=device,
                dtype=self.transformer.dtype,
                batch_size=batch_size,
                num_videos_per_prompt=num_videos_per_prompt,
                prompt_embeds=negative_prompt_embeds,
                prompt_embeds_mask=negative_prompt_embeds_mask,
                prompt_embeds_2=negative_prompt_embeds_2,
                prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
            )

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)

        # 5. Prepare latent variables
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            self.num_channels_latents,
            height,
            width,
            num_frames,
            self.transformer.dtype,
            device,
            generator,
            latents,
        )
        cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, self.transformer.dtype, device)
        image_embeds = torch.zeros(
            batch_size,
            self.vision_num_semantic_tokens,
            self.vision_states_dim,
            dtype=self.transformer.dtype,
            device=device,
        )

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1)
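                # Channel layout of the concatenation above (noted by the editor, derived from the defaults in
                # this PR): 32 noisy latent channels + 32 zero cond-latent channels + 1 zero mask channel = 65
                # channels, which matches the transformer's default `in_channels`.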
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                # Step 1: Collect model inputs needed for the guidance method
                # conditional inputs should always be first element in the tuple
                guider_inputs = {
                    "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
                    "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
                    "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
                    "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
                }

                # Step 2: Update guider's internal state for this denoising step
                self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)

                # Step 3: Prepare batched model inputs based on the guidance method
                # The guider splits model inputs into separate batches for conditional/unconditional predictions.
                # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
                # you will get a guider_state with two batches:
                # guider_state = [
                #     {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"},  # conditional batch
                #     {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},  # unconditional batch
                # ]
                # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
                guider_state = self.guider.prepare_inputs(guider_inputs)
                # Step 4: Run the denoiser for each batch
                # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
                # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
                for guider_state_batch in guider_state:
                    self.guider.prepare_models(self.transformer)

                    # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
                    cond_kwargs = {
                        input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
                    }

                    # e.g. "pred_cond"/"pred_uncond"
                    context_name = getattr(guider_state_batch, self.guider._identifier_key)
                    with self.transformer.cache_context(context_name):
                        # Run denoiser and store noise prediction in this batch
                        guider_state_batch.noise_pred = self.transformer(
                            hidden_states=latent_model_input,
                            image_embeds=image_embeds,
                            timestep=timestep,
                            attention_kwargs=self.attention_kwargs,
                            return_dict=False,
                            **cond_kwargs,
                        )[0]

                    # Cleanup model (e.g., remove hooks)
                    self.guider.cleanup_models(self.transformer)

                # Step 5: Combine predictions using the guidance method
                # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
                # Continuing the CFG example, the guider receives:
                # guider_state = [
                #     {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"},  # batch 0
                #     {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"},  # batch 1
                # ]
                # And extracts predictions using the __guidance_identifier__:
                # pred_cond = guider_state[0]["noise_pred"]  # extracts noise_pred_cond
                # pred_uncond = guider_state[1]["noise_pred"]  # extracts noise_pred_uncond
                # Then applies CFG formula:
                # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
                # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
                noise_pred = self.guider(guider_state)[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        # 8. decode the latents to video and postprocess
        if not output_type == "latent":
            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
            video = self.vae.decode(latents, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return HunyuanVideo15PipelineOutput(frames=video)
@@ -0,0 +1,950 @@
# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
from transformers import (
    ByT5Tokenizer,
    Qwen2_5_VLTextModel,
    Qwen2Tokenizer,
    SiglipImageProcessor,
    SiglipVisionModel,
    T5EncoderModel,
)

from ...guiders import ClassifierFreeGuidance
from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .image_processor import HunyuanVideo15ImageProcessor
from .pipeline_output import HunyuanVideo15PipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import HunyuanVideo15ImageToVideoPipeline
        >>> from diffusers.utils import export_to_video, load_image

        >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_i2v"
        >>> pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        >>> pipe.vae.enable_tiling()
        >>> pipe.to("cuda")

        >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG")

        >>> output = pipe(
        ...     prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        ...     image=image,
        ...     num_inference_steps=50,
        ... ).frames[0]
        >>> export_to_video(output, "output.mp4", fps=24)
        ```
"""


# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input
def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]:
    """
    Apply the chat template to the input text.

    Args:
        prompt (List[str]): Input text.
        system_message (str): System message.

    Returns:
        List[Dict[str, Any]]: List of chat conversations.
    """

    template = [
        [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt
    ]

    return template


# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts
def extract_glyph_texts(prompt: str) -> List[str]:
    """
    Extract glyph texts from the prompt using a regex pattern.

    Args:
        prompt: Input prompt string

    Returns:
        A formatted string of the extracted glyph texts, or `None` if the prompt contains no quoted text
    """
    pattern = r"\"(.*?)\"|“(.*?)”"
    matches = re.findall(pattern, prompt)
    result = [match[0] or match[1] for match in matches]
    result = list(dict.fromkeys(result)) if len(result) > 1 else result

    if result:
        formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". "
    else:
        formatted_result = None

    return formatted_result
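

# For illustration (hypothetical prompt): extract_glyph_texts('A neon sign that says "OPEN" above a door marked "24H"')
# returns 'Text "OPEN". Text "24H". '; both ASCII and curly quotes are matched, and duplicate quoted spans are
# removed while preserving their order.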


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
    r"""
    Pipeline for image-to-video generation using HunyuanVideo1.5.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        transformer ([`HunyuanVideo15Transformer3DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
        vae ([`AutoencoderKLHunyuanVideo15`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`Qwen2_5_VLTextModel`]):
            The text encoder of [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
        tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
        text_encoder_2 ([`T5EncoderModel`]):
            [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel),
            used as the secondary (glyph text) encoder.
        tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer].
        guider ([`ClassifierFreeGuidance`]):
            [`ClassifierFreeGuidance`] component used for classifier-free guidance.
        image_encoder ([`SiglipVisionModel`]):
            [SiglipVisionModel](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipVisionModel)
            used to encode the conditioning image.
        feature_extractor ([`SiglipImageProcessor`]):
            [SiglipImageProcessor](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipImageProcessor)
            used to preprocess the conditioning image for the image encoder.
    """

    model_cpu_offload_seq = "image_encoder->text_encoder->transformer->vae"

    def __init__(
        self,
        text_encoder: Qwen2_5_VLTextModel,
        tokenizer: Qwen2Tokenizer,
        transformer: HunyuanVideo15Transformer3DModel,
        vae: AutoencoderKLHunyuanVideo15,
        scheduler: FlowMatchEulerDiscreteScheduler,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: ByT5Tokenizer,
        guider: ClassifierFreeGuidance,
        image_encoder: SiglipVisionModel,
        feature_extractor: SiglipImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            guider=guider,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )

        self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
        self.video_processor = HunyuanVideo15ImageProcessor(
            vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True
        )
        self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640
        self.vision_states_dim = (
            self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152
        )
        self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32
        # fmt: off
        self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \
            1. The main content and theme of the video. \
            2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
            3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
            4. background environment, light, style and atmosphere. \
            5. camera angles, movements, and transitions used in the video."
        # fmt: on
        self.prompt_template_encode_start_idx = 108
        self.tokenizer_max_length = 1000
        self.tokenizer_2_max_length = 256
        self.vision_num_semantic_tokens = 729

    @staticmethod
    # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds
    def _get_mllm_prompt_embeds(
        text_encoder: Qwen2_5_VLTextModel,
        tokenizer: Qwen2Tokenizer,
        prompt: Union[str, List[str]],
        device: torch.device,
        tokenizer_max_length: int = 1000,
        num_hidden_layers_to_skip: int = 2,
        # fmt: off
        system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \
            1. The main content and theme of the video. \
            2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
            3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
            4. background environment, light, style and atmosphere. \
            5. camera angles, movements, and transitions used in the video.",
        # fmt: on
        crop_start: int = 108,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        prompt = [prompt] if isinstance(prompt, str) else prompt

        prompt = format_text_input(prompt, system_message)

        text_inputs = tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            padding="max_length",
            max_length=tokenizer_max_length + crop_start,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device=device)
        prompt_attention_mask = text_inputs.attention_mask.to(device=device)

        prompt_embeds = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_attention_mask,
            output_hidden_states=True,
        ).hidden_states[-(num_hidden_layers_to_skip + 1)]

        if crop_start is not None and crop_start > 0:
            prompt_embeds = prompt_embeds[:, crop_start:]
            prompt_attention_mask = prompt_attention_mask[:, crop_start:]

        return prompt_embeds, prompt_attention_mask
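
    # The chat template prefixes the system message and role tokens; `crop_start` (108 by default, stored on the
    # pipeline as `prompt_template_encode_start_idx`) is the length of that prefix. It is added to `max_length`
    # during tokenization and then sliced off above, so the full `tokenizer_max_length` budget remains for the
    # user prompt and only the prompt tokens condition the transformer.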

    @staticmethod
    # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds
    def _get_byt5_prompt_embeds(
        tokenizer: ByT5Tokenizer,
        text_encoder: T5EncoderModel,
        prompt: Union[str, List[str]],
        device: torch.device,
        tokenizer_max_length: int = 256,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt

        glyph_texts = [extract_glyph_texts(p) for p in prompt]

        prompt_embeds_list = []
        prompt_embeds_mask_list = []

        for glyph_text in glyph_texts:
            if glyph_text is None:
                glyph_text_embeds = torch.zeros(
                    (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype
                )
                glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64)
            else:
                txt_tokens = tokenizer(
                    glyph_text,
                    padding="max_length",
                    max_length=tokenizer_max_length,
                    truncation=True,
                    add_special_tokens=True,
                    return_tensors="pt",
                ).to(device)

                glyph_text_embeds = text_encoder(
                    input_ids=txt_tokens.input_ids,
                    attention_mask=txt_tokens.attention_mask.float(),
                )[0]
                glyph_text_embeds = glyph_text_embeds.to(device=device)
                glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device)

            prompt_embeds_list.append(glyph_text_embeds)
            prompt_embeds_mask_list.append(glyph_text_embeds_mask)

        prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
        prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0)

        return prompt_embeds, prompt_embeds_mask
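
    # Prompts without quoted glyph text get all-zero embeddings and an all-zero attention mask of length
    # `tokenizer_max_length` as placeholders, so the returned tensors keep the same shape whether or not
    # glyph text is present in the prompt.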

    @staticmethod
    def _get_image_latents(
        vae: AutoencoderKLHunyuanVideo15,
        image_processor: HunyuanVideo15ImageProcessor,
        image: PIL.Image.Image,
        height: int,
        width: int,
        device: torch.device,
    ) -> torch.Tensor:
        vae_dtype = vae.dtype
        image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype)
        image_tensor = image_tensor.unsqueeze(2)
        image_latents = retrieve_latents(vae.encode(image_tensor), sample_mode="argmax")
        image_latents = image_latents * vae.config.scaling_factor
        return image_latents

    @staticmethod
    def _get_image_embeds(
        image_encoder: SiglipVisionModel,
        feature_extractor: SiglipImageProcessor,
        image: PIL.Image.Image,
        device: torch.device,
    ) -> torch.Tensor:
        image_encoder_dtype = next(image_encoder.parameters()).dtype
        image = feature_extractor.preprocess(images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True)
        image = image.to(device=device, dtype=image_encoder_dtype)
        image_enc_hidden_states = image_encoder(**image).last_hidden_state

        return image_enc_hidden_states

    def encode_image(
        self,
        image: PIL.Image.Image,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        image_embeds = self._get_image_embeds(
            image_encoder=self.image_encoder,
            feature_extractor=self.feature_extractor,
            image=image,
            device=device,
        )
        image_embeds = image_embeds.repeat(batch_size, 1, 1)
        image_embeds = image_embeds.to(device=device, dtype=dtype)
        return image_embeds

    # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        batch_size: int = 1,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        prompt_embeds_2: Optional[torch.Tensor] = None,
        prompt_embeds_mask_2: Optional[torch.Tensor] = None,
    ):
        r"""
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            batch_size (`int`):
                batch size of prompts, defaults to 1
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
                argument.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
            prompt_embeds_2 (`torch.Tensor`, *optional*):
                Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
                argument using self.tokenizer_2 and self.text_encoder_2.
            prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
                Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
                argument using self.tokenizer_2 and self.text_encoder_2.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        if prompt is None:
            prompt = [""] * batch_size

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds(
                tokenizer=self.tokenizer,
                text_encoder=self.text_encoder,
                prompt=prompt,
                device=device,
                tokenizer_max_length=self.tokenizer_max_length,
                system_message=self.system_message,
                crop_start=self.prompt_template_encode_start_idx,
            )

        if prompt_embeds_2 is None:
            prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds(
                tokenizer=self.tokenizer_2,
                text_encoder=self.text_encoder_2,
                prompt=prompt,
                device=device,
                tokenizer_max_length=self.tokenizer_2_max_length,
            )

        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len)

        _, seq_len_2, _ = prompt_embeds_2.shape
        prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1)
        prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2)

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device)
        prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device)
        prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device)

        return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2

    def check_inputs(
        self,
        prompt,
        image: PIL.Image.Image,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_embeds_mask=None,
        negative_prompt_embeds_mask=None,
        prompt_embeds_2=None,
        prompt_embeds_mask_2=None,
        negative_prompt_embeds_2=None,
        negative_prompt_embeds_mask_2=None,
    ):
        if not isinstance(image, PIL.Image.Image):
            raise ValueError(f"`image` has to be of type `PIL.Image.Image` but is {type(image)}")

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and prompt_embeds_mask is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        if prompt is None and prompt_embeds_2 is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
            )

        if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
            raise ValueError(
                "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
            )
        if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
            raise ValueError(
                "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
            )

    # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 32,
        height: int = 720,
        width: int = 1280,
        num_frames: int = 129,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        shape = (
            batch_size,
            num_channels_latents,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents
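
    # The latent grid has (num_frames - 1) // vae_scale_factor_temporal + 1 frames at a resolution of
    # height // vae_scale_factor_spatial x width // vae_scale_factor_spatial; e.g. the default 121 pixel
    # frames become 31 latent frames with a temporal compression ratio of 4.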

    def prepare_cond_latents_and_mask(
        self,
        latents: torch.Tensor,
        image: PIL.Image.Image,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """
        Prepare the image-conditioning latents and mask for i2v generation.

        Args:
            latents: Main latents tensor (B, C, F, H, W)

        Returns:
            tuple: (latent_condition, latent_mask) - the encoded input image placed on the first latent frame (all
            later frames zeroed) and a mask that is 1.0 on the first latent frame and 0.0 elsewhere.
        """

        batch, channels, frames, height, width = latents.shape

        image_latents = self._get_image_latents(
            vae=self.vae,
            image_processor=self.video_processor,
            image=image,
            height=height,
            width=width,
            device=device,
        )

        latent_condition = image_latents.repeat(batch_size, 1, frames, 1, 1)
        latent_condition[:, :, 1:, :, :] = 0
        latent_condition = latent_condition.to(device=device, dtype=dtype)

        latent_mask = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device)
        latent_mask[:, :, 0, :, :] = 1.0

        return latent_condition, latent_mask

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: PIL.Image.Image,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        num_frames: int = 121,
        num_inference_steps: int = 50,
        sigmas: List[float] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
        prompt_embeds_2: Optional[torch.Tensor] = None,
        prompt_embeds_mask_2: Optional[torch.Tensor] = None,
        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
        negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`PIL.Image.Image`):
                The input image to condition video generation on.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
            num_frames (`int`, defaults to `121`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                Pre-generated mask for prompt embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
                Pre-generated mask for negative prompt embeddings.
            prompt_embeds_2 (`torch.Tensor`, *optional*):
                Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs.
            prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
                Pre-generated mask for prompt embeddings from the second text encoder.
            negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings from the second text encoder.
            negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
                Pre-generated mask for negative prompt embeddings from the second text encoder.
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated video. Choose between "np", "pt", or "latent".
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

        Examples:

        Returns:
            [`~HunyuanVideo15PipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated videos.
        """

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            prompt_embeds_2=prompt_embeds_2,
            prompt_embeds_mask_2=prompt_embeds_mask_2,
            negative_prompt_embeds_2=negative_prompt_embeds_2,
            negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
        )

        height, width = self.video_processor.calculate_default_height_width(
            height=image.size[1], width=image.size[0], target_size=self.target_size
        )
        image = self.video_processor.resize(image, height=height, width=width, resize_mode="crop")

        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode image
        image_embeds = self.encode_image(
            image=image,
            batch_size=batch_size * num_videos_per_prompt,
            device=device,
            dtype=self.transformer.dtype,
        )

        # 4. Encode input prompt
        prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
            prompt=prompt,
            device=device,
            dtype=self.transformer.dtype,
            batch_size=batch_size,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            prompt_embeds_2=prompt_embeds_2,
            prompt_embeds_mask_2=prompt_embeds_mask_2,
        )

        if self.guider._enabled and self.guider.num_conditions > 1:
            (
                negative_prompt_embeds,
                negative_prompt_embeds_mask,
                negative_prompt_embeds_2,
                negative_prompt_embeds_mask_2,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                device=device,
                dtype=self.transformer.dtype,
                batch_size=batch_size,
                num_videos_per_prompt=num_videos_per_prompt,
                prompt_embeds=negative_prompt_embeds,
                prompt_embeds_mask=negative_prompt_embeds_mask,
                prompt_embeds_2=negative_prompt_embeds_2,
                prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
            )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            batch_size=batch_size * num_videos_per_prompt,
            num_channels_latents=self.num_channels_latents,
            height=height,
            width=width,
            num_frames=num_frames,
            dtype=self.transformer.dtype,
            device=device,
            generator=generator,
            latents=latents,
        )

        cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(
            latents=latents,
            image=image,
            batch_size=batch_size * num_videos_per_prompt,
            height=height,
            width=width,
            dtype=self.transformer.dtype,
            device=device,
        )

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1)
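                # The transformer input stacks the noisy latents, the image-conditioning latents and the
                # single-channel frame mask along the channel axis, so its in_channels equals
                # 2 * latent_channels + 1 (e.g. 9 when latent_channels is 4, as in the fast tests below).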
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                # Step 1: Collect model inputs needed for the guidance method
                # conditional inputs should always be first element in the tuple
                guider_inputs = {
                    "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
                    "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
                    "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
                    "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
                }

                # Step 2: Update guider's internal state for this denoising step
                self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)

                # Step 3: Prepare batched model inputs based on the guidance method
                # The guider splits model inputs into separate batches for conditional/unconditional predictions.
                # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
                # you will get a guider_state with two batches:
                # guider_state = [
                #     {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"},  # conditional batch
                #     {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},  # unconditional batch
                # ]
                # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
                guider_state = self.guider.prepare_inputs(guider_inputs)

                # Step 4: Run the denoiser for each batch
                # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
                # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
                for guider_state_batch in guider_state:
                    self.guider.prepare_models(self.transformer)

                    # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
                    cond_kwargs = {
                        input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
                    }

                    # e.g. "pred_cond"/"pred_uncond"
                    context_name = getattr(guider_state_batch, self.guider._identifier_key)
                    with self.transformer.cache_context(context_name):
                        # Run denoiser and store noise prediction in this batch
                        guider_state_batch.noise_pred = self.transformer(
                            hidden_states=latent_model_input,
                            image_embeds=image_embeds,
                            timestep=timestep,
                            attention_kwargs=self.attention_kwargs,
                            return_dict=False,
                            **cond_kwargs,
                        )[0]

                    # Cleanup model (e.g., remove hooks)
                    self.guider.cleanup_models(self.transformer)

                # Step 5: Combine predictions using the guidance method
                # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
                # Continuing the CFG example, the guider receives:
                # guider_state = [
                #     {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"},  # batch 0
                #     {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"},  # batch 1
                # ]
                # And extracts predictions using the __guidance_identifier__:
                # pred_cond = guider_state[0]["noise_pred"]  # extracts noise_pred_cond
                # pred_uncond = guider_state[1]["noise_pred"]  # extracts noise_pred_uncond
                # Then applies CFG formula:
                # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
                # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
                noise_pred = self.guider(guider_state)[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
            video = self.vae.decode(latents, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return HunyuanVideo15PipelineOutput(frames=video)

src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py (new file)
@@ -0,0 +1,20 @@
from dataclasses import dataclass

import torch

from diffusers.utils import BaseOutput


@dataclass
class HunyuanVideo15PipelineOutput(BaseOutput):
    r"""
    Output class for HunyuanVideo1.5 pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    frames: torch.Tensor

@@ -468,6 +468,21 @@ class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class AutoencoderKLLTXVideo(metaclass=DummyObject):
    _backends = ["torch"]

@@ -993,6 +1008,21 @@ class HunyuanImageTransformer2DModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class HunyuanVideo15Transformer3DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -1142,6 +1142,36 @@ class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class HunyuanVideo15ImageToVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class HunyuanVideo15Pipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class HunyuanVideoFramepackPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

tests/models/transformers/test_models_transformer_hunyuan_1_5.py (new file)
@@ -0,0 +1,100 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import HunyuanVideo15Transformer3DModel

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase):
    model_class = HunyuanVideo15Transformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    text_embed_dim = 16
    text_embed_2_dim = 8
    image_embed_dim = 12

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 4
        num_frames = 1
        height = 8
        width = 8
        sequence_length = 6
        sequence_length_2 = 4
        image_sequence_length = 3

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.tensor([1.0]).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device)
        encoder_hidden_states_2 = torch.randn(
            (batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device
        )
        encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device)
        encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device)
        # All zeros for inducing T2V path in the model.
        image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
            "encoder_hidden_states_2": encoder_hidden_states_2,
            "encoder_attention_mask_2": encoder_attention_mask_2,
            "image_embeds": image_embeds,
        }

    @property
    def input_shape(self):
        return (4, 1, 8, 8)

    @property
    def output_shape(self):
        return (4, 1, 8, 8)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 4,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 8,
            "num_layers": 2,
            "num_refiner_layers": 1,
            "mlp_ratio": 2.0,
            "patch_size": 1,
            "patch_size_t": 1,
            "text_embed_dim": self.text_embed_dim,
            "text_embed_2_dim": self.text_embed_2_dim,
            "image_embed_dim": self.image_embed_dim,
            "rope_axes_dim": (2, 2, 4),
            "target_size": 16,
            "task_type": "t2v",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideo15Transformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/pipelines/hunyuan_video1_5/__init__.py (new file)
@@ -0,0 +1 @@
# Copyright 2025 The HuggingFace Team.

tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py (new file)
@@ -0,0 +1,187 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from transformers import ByT5Tokenizer, Qwen2_5_VLTextConfig, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKLHunyuanVideo15,
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideo15Pipeline,
    HunyuanVideo15Transformer3DModel,
)
from diffusers.guiders import ClassifierFreeGuidance

from ...testing_utils import enable_full_determinism
from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()


class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = HunyuanVideo15Pipeline
    params = frozenset(
        [
            "prompt",
            "negative_prompt",
            "height",
            "width",
            "prompt_embeds",
            "prompt_embeds_mask",
            "negative_prompt_embeds",
            "negative_prompt_embeds_mask",
            "prompt_embeds_2",
            "prompt_embeds_mask_2",
            "negative_prompt_embeds_2",
            "negative_prompt_embeds_mask_2",
        ]
    )
    batch_params = ["prompt", "negative_prompt"]
    required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
    test_attention_slicing = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = False
    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1):
        torch.manual_seed(0)
        transformer = HunyuanVideo15Transformer3DModel(
            in_channels=9,
            out_channels=4,
            num_attention_heads=2,
            attention_head_dim=8,
            num_layers=num_layers,
            num_refiner_layers=1,
            mlp_ratio=2.0,
            patch_size=1,
            patch_size_t=1,
            text_embed_dim=16,
            text_embed_2_dim=32,
            image_embed_dim=12,
            rope_axes_dim=(2, 2, 4),
            target_size=16,
            task_type="t2v",
        )

        torch.manual_seed(0)
        vae = AutoencoderKLHunyuanVideo15(
            in_channels=3,
            out_channels=3,
            latent_channels=4,
            block_out_channels=(16, 16),
            layers_per_block=1,
            spatial_compression_ratio=4,
            temporal_compression_ratio=2,
            downsample_match_channel=False,
            upsample_match_channel=False,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

        torch.manual_seed(0)
        qwen_config = Qwen2_5_VLTextConfig(
            **{
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
            }
        )
        text_encoder = Qwen2_5_VLTextModel(qwen_config)
        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        torch.manual_seed(0)
        text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer_2 = ByT5Tokenizer()

        guider = ClassifierFreeGuidance(guidance_scale=1.0)

        components = {
            "transformer": transformer.eval(),
            "vae": vae.eval(),
            "scheduler": scheduler,
            "text_encoder": text_encoder.eval(),
            "text_encoder_2": text_encoder_2.eval(),
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "guider": guider,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "monkey",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 16,
            "width": 16,
            "num_frames": 9,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        result = pipe(**inputs)
        video = result.frames

        generated_video = video[0]
        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
        generated_slice = generated_video.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])

        # fmt: off
        expected_slice = torch.tensor([0.4296, 0.5549, 0.3088, 0.9115, 0.5049, 0.7926, 0.5549, 0.8618, 0.5091, 0.5075, 0.7117, 0.5292, 0.7053, 0.4864, 0.5206, 0.3878])
        # fmt: on

        self.assertTrue(
            torch.abs(generated_slice - expected_slice).max() < 1e-3,
            f"output_slice: {generated_slice}, expected_slice: {expected_slice}",
        )

    @unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
    def test_encode_prompt_works_in_isolation(self):
        pass

    @unittest.skip("Needs to be revisited.")
    def test_inference_batch_consistent(self):
        super().test_inference_batch_consistent()

    @unittest.skip("Needs to be revisited.")
    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical()