SANA-Video Image to Video pipeline SanaImageToVideoPipeline support (#12634)

* move sana-video to a new dir and add `SanaImageToVideoPipeline` with no modify;

* fix bug and run text/image-to-vidoe success;

* make style; quality; fix-copies;

* add sana image-to-video pipeline in markdown;

* add test case for sana image-to-video;

* make style;

* add a init file in sana-video test dir;

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* minor update;

* fix bug and skip fp16 save test;

Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* add copied from for `encode_prompt`

* Apply style fixes

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Junsong Chen
2025-11-17 16:23:34 +08:00
committed by GitHub
parent 0c35b580fe
commit 1afc21855e
15 changed files with 1501 additions and 34 deletions

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaVideoPipeline
# Sana-Video
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
@@ -37,6 +37,85 @@ Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-vi
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
## Generation Pipelines
<hfoptions id="generation pipelines">`
<hfoption id="Text-to-Video">
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text descriptio and a starting frame.
```python
model_id =
pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", torch_dtype=torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana_video.mp4", fps=16)
```
</hfoption>
<hfoption id="Image-to-Video">
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text descriptio and a starting frame.
```python
model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
pipe = SanaImageToVideoPipeline.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.vae.to(torch.float32)
pipe.text_encoder.to(torch.bfloat16)
pipe.to("cuda")
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
motion_scale = 30.0
video = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana-i2v.mp4", fps=16)
```
</hfoption>
</hfoptions>
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
@@ -97,6 +176,13 @@ export_to_video(output, "sana-video-output.mp4", fps=16)
- __call__
## SanaImageToVideoPipeline
[[autodoc]] SanaImageToVideoPipeline
- all
- __call__
## SanaVideoPipelineOutput
[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput

View File

@@ -80,6 +80,8 @@ def main(args):
# scheduler
flow_shift = 8.0
if args.task == "i2v":
assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
# model config
layer_num = 20
@@ -312,6 +314,7 @@ if __name__ == "__main__":
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

View File

@@ -545,11 +545,13 @@ else:
"QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaImageToVideoPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
"SanaVideoPipeline",
"SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -1227,6 +1229,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
SanaImageToVideoPipeline,
SanaPAGPipeline,
SanaPipeline,
SanaSprintImg2ImgPipeline,

View File

@@ -237,7 +237,6 @@ class WanRotaryPosEmbed(nn.Module):
return freqs_cos, freqs_sin
# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
class SanaModulatedNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
super().__init__()
@@ -247,7 +246,7 @@ class SanaModulatedNorm(nn.Module):
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
shift, scale = (scale_shift_table[None, None] + temb[:, :, None].to(scale_shift_table.device)).unbind(dim=2)
hidden_states = hidden_states * (1 + scale) + shift
return hidden_states
@@ -423,8 +422,8 @@ class SanaVideoTransformerBlock(nn.Module):
# 1. Modulation
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1)
).unbind(dim=2)
# 2. Self Attention
norm_hidden_states = self.norm1(hidden_states)
@@ -635,13 +634,16 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
if guidance is not None:
timestep, embedded_timestep = self.time_embed(
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype
)
else:
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
timestep = timestep.view(batch_size, -1, timestep.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

View File

@@ -308,7 +308,10 @@ else:
"SanaSprintPipeline",
"SanaControlNetPipeline",
"SanaSprintImg2ImgPipeline",
]
_import_structure["sana_video"] = [
"SanaVideoPipeline",
"SanaImageToVideoPipeline",
]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -749,8 +752,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SanaPipeline,
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
SanaVideoPipeline,
)
from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel

View File

@@ -26,7 +26,6 @@ else:
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
_import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -40,7 +39,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_sana_controlnet import SanaControlNetPipeline
from .pipeline_sana_sprint import SanaSprintPipeline
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
from .pipeline_sana_video import SanaVideoPipeline
else:
import sys

View File

@@ -3,7 +3,6 @@ from typing import List, Union
import numpy as np
import PIL.Image
import torch
from ...utils import BaseOutput
@@ -20,18 +19,3 @@ class SanaPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
@dataclass
class SanaVideoPipelineOutput(BaseOutput):
r"""
Output class for Sana-Video pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor

View File

@@ -0,0 +1,49 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
_import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_sana_video import SanaVideoPipeline
from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from ...utils import BaseOutput
@dataclass
class SanaVideoPipelineOutput(BaseOutput):
r"""
Output class for Sana-Video pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor

View File

@@ -95,17 +95,16 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import SanaVideoPipeline
>>> from diffusers.utils import export_to_video
>>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
>>> pipe = SanaVideoPipeline.from_pretrained(model_id)
>>> pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
>>> pipe.transformer.to(torch.bfloat16)
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.vae.to(torch.float32)
>>> pipe.to("cuda")
>>> model_score = 30
>>> motion_score = 30
>>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
>>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
>>> motion_prompt = f" motion score: {model_score}."
>>> motion_prompt = f" motion score: {motion_score}."
>>> prompt = prompt + motion_prompt
>>> output = pipe(
@@ -231,6 +230,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
@@ -827,9 +827,9 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Examples:
Returns:
[`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated videos
[`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated videos
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):

File diff suppressed because it is too large Load Diff

View File

@@ -2147,6 +2147,21 @@ class SanaControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class SanaImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class SanaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

View File

@@ -0,0 +1,238 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
SanaImageToVideoPipeline,
SanaVideoTransformer3DModel,
)
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class SanaImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = SanaImageToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
text_encoder_config = Gemma2Config(
head_dim=16,
hidden_size=8,
initializer_range=0.02,
intermediate_size=64,
max_position_embeddings=8192,
model_type="gemma2",
num_attention_heads=2,
num_hidden_layers=1,
num_key_value_heads=2,
vocab_size=8,
attn_implementation="eager",
)
text_encoder = Gemma2Model(text_encoder_config)
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
torch.manual_seed(0)
transformer = SanaVideoTransformer3DModel(
in_channels=16,
out_channels=16,
num_attention_heads=2,
attention_head_dim=12,
num_layers=2,
num_cross_attention_heads=2,
cross_attention_head_dim=12,
cross_attention_dim=24,
caption_channels=8,
mlp_ratio=2.5,
dropout=0.0,
attention_bias=False,
sample_size=8,
patch_size=(1, 2, 2),
norm_elementwise_affine=False,
norm_eps=1e-6,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
# Create a dummy image input (PIL Image)
image = Image.new("RGB", (32, 32))
inputs = {
"image": image,
"prompt": "",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"height": 32,
"width": 32,
"frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
"complex_human_instruction": [],
"use_resolution_binning": False,
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
def test_save_load_local(self, expected_max_difference=5e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
self.assertLess(max_diff, expected_max_difference)
# TODO(aryan): Create a dummy gemma model with smol vocab size
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_consistent(self):
pass
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_single_identical(self):
pass
@unittest.skip("Skipping fp16 test as model is trained with bf16")
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
@unittest.skip("Skipping fp16 test as model is trained with bf16")
def test_save_load_float16(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_save_load_float16(expected_max_diff=0.2)
@slow
@require_torch_accelerator
class SanaVideoPipelineIntegrationTests(unittest.TestCase):
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest."
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@unittest.skip("TODO: test needs to be implemented")
def test_sana_video_480p(self):
pass