Hunyuan Video Framepack F1 (#11534)

* support framepack f1

* update docs

* update toctree

* remove typo
This commit is contained in:
Aryan
2025-05-12 16:11:10 +05:30
committed by GitHub
parent 01abfc8736
commit e48f6aeeb4
4 changed files with 349 additions and 48 deletions

View File

@@ -457,6 +457,8 @@
title: Flux
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit

View File

@@ -0,0 +1,209 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Framepack
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
[Packing Input Frame Context in Next-Frame Prediction Models for Video Generation](https://arxiv.org/abs/2504.12626) by Lvmin Zhang and Maneesh Agrawala.
*We present a neural network structure, FramePack, to train next-frame (or next-frame-section) prediction models for video generation. The FramePack compresses input frames to make the transformer context length a fixed number regardless of the video length. As a result, we are able to process a large number of frames using video diffusion with computation bottleneck similar to image diffusion. This also makes the training video batch sizes significantly higher (batch sizes become comparable to image diffusion training). We also propose an anti-drifting sampling method that generates frames in inverted temporal order with early-established endpoints to avoid exposure bias (error accumulation over iterations). Finally, we show that existing video diffusion models can be finetuned with FramePack, and their visual quality may be improved because the next-frame prediction supports more balanced diffusion schedulers with less extreme flow shift timesteps.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## Available models
| Model name | Description |
|:---|:---|
- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | Trained with the "inverted anti-drifting" strategy as described in the paper. Inference requires setting `sampling_type="inverted_anti_drifting"` when running the pipeline. |
- [`lllyasviel/FramePack_F1_I2V_HY_20250503`](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503) | Trained with a novel anti-drifting strategy but inference is performed in "vanilla" strategy as described in the paper. Inference requires setting `sampling_type="vanilla"` when running the pipeline. |
## Usage
Refer to the pipeline documentation for basic usage examples. The following section contains examples of offloading, different sampling methods, quantization, and more.
### First and last frame to video
The following example shows how to use Framepack with start and end image controls, using the inverted anti-drifiting sampling model.
```python
import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
"lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
torch_dtype=torch.float16,
)
# Enable memory optimizations
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
first_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
)
last_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
)
output = pipe(
image=first_image,
last_image=last_image,
prompt=prompt,
height=512,
width=512,
num_frames=91,
num_inference_steps=30,
guidance_scale=9.0,
generator=torch.Generator().manual_seed(0),
sampling_type="inverted_anti_drifting",
).frames[0]
export_to_video(output, "output.mp4", fps=30)
```
### Vanilla sampling
The following example shows how to use Framepack with the F1 model trained with vanilla sampling but new regulation approach for anti-drifting.
```python
import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
"lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
torch_dtype=torch.float16,
)
# Enable memory optimizations
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
)
output = pipe(
image=image,
prompt="A penguin dancing in the snow",
height=832,
width=480,
num_frames=91,
num_inference_steps=30,
guidance_scale=9.0,
generator=torch.Generator().manual_seed(0),
sampling_type="vanilla",
).frames[0]
export_to_video(output, "output.mp4", fps=30)
```
### Group offloading
Group offloading ([`~hooks.apply_group_offloading`]) provides aggressive memory optimizations for offloading internal parts of any model to the CPU, with possibly no additional overhead to generation time. If you have very low VRAM available, this approach may be suitable for you depending on the amount of CPU RAM available.
```python
import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
"lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
torch_dtype=torch.float16,
)
# Enable group offloading
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
list(map(
lambda x: apply_group_offloading(x, onload_device, offload_device, offload_type="leaf_level", use_stream=True, low_cpu_mem_usage=True),
[pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]
))
pipe.image_encoder.to(onload_device)
pipe.vae.to(onload_device)
pipe.vae.enable_tiling()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
)
output = pipe(
image=image,
prompt="A penguin dancing in the snow",
height=832,
width=480,
num_frames=91,
num_inference_steps=30,
guidance_scale=9.0,
generator=torch.Generator().manual_seed(0),
sampling_type="vanilla",
).frames[0]
print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")
export_to_video(output, "output.mp4", fps=30)
```
## HunyuanVideoFramepackPipeline
[[autodoc]] HunyuanVideoFramepackPipeline
- all
- __call__
## HunyuanVideoPipelineOutput
[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput

View File

@@ -52,7 +52,6 @@ The following models are available for the image-to-video pipeline:
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | lllyasviel's paper introducing a new technique for long-context video generation called [Framepack](https://arxiv.org/abs/2504.12626). |
## Quantization

View File

@@ -14,6 +14,7 @@
import inspect
import math
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -91,6 +92,7 @@ EXAMPLE_DOC_STRING = """
... num_inference_steps=30,
... guidance_scale=9.0,
... generator=torch.Generator().manual_seed(0),
... sampling_type="inverted_anti_drifting",
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=30)
```
@@ -138,6 +140,7 @@ EXAMPLE_DOC_STRING = """
... num_inference_steps=30,
... guidance_scale=9.0,
... generator=torch.Generator().manual_seed(0),
... sampling_type="inverted_anti_drifting",
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=30)
```
@@ -232,6 +235,11 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class FramepackSamplingType(str, Enum):
VANILLA = "vanilla"
INVERTED_ANTI_DRIFTING = "inverted_anti_drifting"
class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using HunyuanVideo.
@@ -455,6 +463,11 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
prompt_template=None,
image=None,
image_latents=None,
last_image=None,
last_image_latents=None,
sampling_type=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -493,6 +506,21 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
)
sampling_types = [x.value for x in FramepackSamplingType.__members__.values()]
if sampling_type not in sampling_types:
raise ValueError(f"`sampling_type` has to be one of '{sampling_types}' but is '{sampling_type}'")
if image is not None and image_latents is not None:
raise ValueError("Only one of `image` or `image_latents` can be passed.")
if last_image is not None and last_image_latents is not None:
raise ValueError("Only one of `last_image` or `last_image_latents` can be passed.")
if sampling_type != FramepackSamplingType.INVERTED_ANTI_DRIFTING and (
last_image is not None or last_image_latents is not None
):
raise ValueError(
'Only `"inverted_anti_drifting"` inference type supports `last_image` or `last_image_latents`.'
)
def prepare_latents(
self,
batch_size: int = 1,
@@ -623,6 +651,7 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
max_sequence_length: int = 256,
sampling_type: FramepackSamplingType = FramepackSamplingType.INVERTED_ANTI_DRIFTING,
):
r"""
The call function to the pipeline for generation.
@@ -735,6 +764,11 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
prompt_embeds,
callback_on_step_end_tensor_inputs,
prompt_template,
image,
image_latents,
last_image,
last_image_latents,
sampling_type,
)
has_neg_prompt = negative_prompt is not None or (
@@ -806,18 +840,6 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
num_channels_latents = self.transformer.config.in_channels
window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1
num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
# Specific to the released checkpoint: https://huggingface.co/lllyasviel/FramePackI2V_HY
# TODO: find a more generic way in future if there are more checkpoints
history_sizes = [1, 2, 16]
history_latents = torch.zeros(
batch_size,
num_channels_latents,
sum(history_sizes),
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
device=device,
dtype=torch.float32,
)
history_video = None
total_generated_latent_frames = 0
@@ -829,38 +851,92 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
last_image, dtype=torch.float32, device=device, generator=generator
)
latent_paddings = list(reversed(range(num_latent_sections)))
if num_latent_sections > 4:
latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0]
# Specific to the released checkpoints:
# - https://huggingface.co/lllyasviel/FramePackI2V_HY
# - https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503
# TODO: find a more generic way in future if there are more checkpoints
if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING:
history_sizes = [1, 2, 16]
history_latents = torch.zeros(
batch_size,
num_channels_latents,
sum(history_sizes),
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
device=device,
dtype=torch.float32,
)
elif sampling_type == FramepackSamplingType.VANILLA:
history_sizes = [16, 2, 1]
history_latents = torch.zeros(
batch_size,
num_channels_latents,
sum(history_sizes),
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
device=device,
dtype=torch.float32,
)
history_latents = torch.cat([history_latents, image_latents], dim=2)
total_generated_latent_frames += 1
else:
assert False
# 6. Prepare guidance condition
guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0
# 7. Denoising loop
for k in range(num_latent_sections):
is_first_section = k == 0
is_last_section = k == num_latent_sections - 1
latent_padding_size = latent_paddings[k] * latent_window_size
if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING:
latent_paddings = list(reversed(range(num_latent_sections)))
if num_latent_sections > 4:
latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0]
indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes]))
(
indices_prefix,
indices_padding,
indices_latents,
indices_postfix,
indices_latents_history_2x,
indices_latents_history_4x,
) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0)
# Inverted anti-drifting sampling: Figure 2(c) in the paper
indices_clean_latents = torch.cat([indices_prefix, indices_postfix], dim=0)
is_first_section = k == 0
is_last_section = k == num_latent_sections - 1
latent_padding_size = latent_paddings[k] * latent_window_size
latents_prefix = image_latents
latents_postfix, latents_history_2x, latents_history_4x = history_latents[
:, :, : sum(history_sizes)
].split(history_sizes, dim=2)
if last_image is not None and is_first_section:
latents_postfix = last_image_latents
latents_clean = torch.cat([latents_prefix, latents_postfix], dim=2)
indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes]))
(
indices_prefix,
indices_padding,
indices_latents,
indices_latents_history_1x,
indices_latents_history_2x,
indices_latents_history_4x,
) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0)
# Inverted anti-drifting sampling: Figure 2(c) in the paper
indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
latents_prefix = image_latents
latents_history_1x, latents_history_2x, latents_history_4x = history_latents[
:, :, : sum(history_sizes)
].split(history_sizes, dim=2)
if last_image is not None and is_first_section:
latents_history_1x = last_image_latents
latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2)
elif sampling_type == FramepackSamplingType.VANILLA:
indices = torch.arange(0, sum([1, *history_sizes, latent_window_size]))
(
indices_prefix,
indices_latents_history_4x,
indices_latents_history_2x,
indices_latents_history_1x,
indices_latents,
) = indices.split([1, *history_sizes, latent_window_size], dim=0)
indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
latents_prefix = image_latents
latents_history_4x, latents_history_2x, latents_history_1x = history_latents[
:, :, -sum(history_sizes) :
].split(history_sizes, dim=2)
latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2)
else:
assert False
latents = self.prepare_latents(
batch_size,
@@ -960,13 +1036,26 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
if XLA_AVAILABLE:
xm.mark_step()
if is_last_section:
latents = torch.cat([image_latents, latents], dim=2)
if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING:
if is_last_section:
latents = torch.cat([image_latents, latents], dim=2)
total_generated_latent_frames += latents.shape[2]
history_latents = torch.cat([latents, history_latents], dim=2)
real_history_latents = history_latents[:, :, :total_generated_latent_frames]
section_latent_frames = (
(latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
)
index_slice = (slice(None), slice(None), slice(0, section_latent_frames))
total_generated_latent_frames += latents.shape[2]
history_latents = torch.cat([latents, history_latents], dim=2)
elif sampling_type == FramepackSamplingType.VANILLA:
total_generated_latent_frames += latents.shape[2]
history_latents = torch.cat([history_latents, latents], dim=2)
real_history_latents = history_latents[:, :, -total_generated_latent_frames:]
section_latent_frames = latent_window_size * 2
index_slice = (slice(None), slice(None), slice(-section_latent_frames, None))
real_history_latents = history_latents[:, :, :total_generated_latent_frames]
else:
assert False
if history_video is None:
if not output_type == "latent":
@@ -976,16 +1065,18 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
history_video = [real_history_latents]
else:
if not output_type == "latent":
section_latent_frames = (
(latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
)
overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1
current_latents = (
real_history_latents[:, :, :section_latent_frames].to(vae_dtype)
/ self.vae.config.scaling_factor
real_history_latents[index_slice].to(vae_dtype) / self.vae.config.scaling_factor
)
current_video = self.vae.decode(current_latents, return_dict=False)[0]
history_video = self._soft_append(current_video, history_video, overlapped_frames)
if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING:
history_video = self._soft_append(current_video, history_video, overlapped_frames)
elif sampling_type == FramepackSamplingType.VANILLA:
history_video = self._soft_append(history_video, current_video, overlapped_frames)
else:
assert False
else:
history_video.append(real_history_latents)