Compare commits

...

5 Commits

Author SHA1 Message Date
cbensimon
bc8817604a [core] Add export-safe LRU cache helper 2026-03-19 17:30:58 +00:00
kaixuanliu
67613369bb fix: 'PaintByExampleImageEncoder' object has no attribute 'all_tied_w… (#13252)
* fix: 'PaintByExampleImageEncoder' object has no attribute 'all_tied_weights_keys'

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* also fix LDMBertModel

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-03-18 17:55:08 -10:00
Shenghai Yuan
0c01a4b5e2 [Helios] Remove lru_cache for better AoTI compatibility and cleaner code (#13282)
fix: drop lru_cache for better AoTI compatibility
2026-03-18 23:41:58 +05:30
kaixuanliu
8e4b5607ed skip invalid test case for helios pipeline (#13218)
* skip invalid test case for helio pipeline

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update skip reason

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2026-03-17 20:58:35 -10:00
Junsong Chen
c6f72ad2f6 add ltx2 vae in sana-video; (#13229)
* add ltx2 vae in sana-video;

* add ltx vae in conversion script;

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* condition `vae_scale_factor_xxx` related settings on VAE types;

* make the mean/std depends on vae class;

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-03-17 18:09:52 -10:00
11 changed files with 149 additions and 56 deletions

View File

@@ -12,6 +12,7 @@ from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderKLLTX2Video,
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
@@ -24,7 +25,10 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
ckpt_ids = [
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
"Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
@@ -92,12 +96,22 @@ def main(args):
if args.video_size == 480:
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
patch_size = (1, 2, 2)
in_channels = 16
out_channels = 16
elif args.video_size == 720:
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
sample_size = 22 # DC-AE-V: 32xp1 downsample factor
patch_size = (1, 1, 1)
in_channels = 32
out_channels = 32
else:
raise ValueError(f"Video size {args.video_size} is not supported.")
if args.vae_type == "ltx2":
sample_size = 22
patch_size = (1, 1, 1)
in_channels = 128
out_channels = 128
for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
@@ -182,8 +196,8 @@ def main(args):
# Transformer
with CTX():
transformer_kwargs = {
"in_channels": 16,
"out_channels": 16,
"in_channels": in_channels,
"out_channels": out_channels,
"num_attention_heads": 20,
"attention_head_dim": 112,
"num_layers": 20,
@@ -235,9 +249,12 @@ def main(args):
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
if args.vae_type == "ltx2":
vae_path = args.vae_path or "Lightricks/LTX-2"
vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
else:
vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
@@ -314,7 +331,23 @@ if __name__ == "__main__":
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument(
"--vae_type",
default="wan",
type=str,
choices=["wan", "ltx2"],
help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
)
parser.add_argument(
"--vae_path",
default=None,
type=str,
required=False,
help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
)
parser.add_argument(
"--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Type
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
return tensor
@functools.lru_cache(maxsize=64)
@lru_cache_unless_export(maxsize=64)
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
gather_shapes = []
for i in range(world_size):

View File

@@ -49,7 +49,7 @@ from ..utils import (
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import maybe_allow_in_graph
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm
@@ -575,7 +575,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)
@functools.lru_cache(maxsize=128)
@lru_cache_unless_export(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,

View File

@@ -13,7 +13,6 @@
# limitations under the License.
import math
from functools import lru_cache
from typing import Any
import torch
@@ -343,7 +342,6 @@ class HeliosRotaryPosEmbed(nn.Module):
return freqs.cos(), freqs.sin()
@torch.no_grad()
@lru_cache(maxsize=32)
def _get_spatial_meshgrid(self, height, width, device_str):
device = torch.device(device_str)
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from math import prod
from typing import Any
@@ -25,7 +24,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=128)
@lru_cache_unless_export(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
@lru_cache_unless_export(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
@lru_cache_unless_export(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs

View File

@@ -720,6 +720,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
super().__init__(config)
self.model = LDMBertEncoder(config)
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
self.post_init()
def forward(
self,

View File

@@ -35,6 +35,8 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
# uncondition for scaling
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
self.post_init()
def forward(self, pixel_values, return_uncond_vector=False):
clip_output = self.model(pixel_values=pixel_values)
latent_states = clip_output.pooler_output

View File

@@ -24,7 +24,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
@@ -213,7 +213,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLWan,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel,
scheduler: DPMSolverMultistepScheduler,
):
@@ -223,8 +223,19 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
if getattr(self, "vae", None):
if isinstance(self.vae, AutoencoderKLLTX2Video):
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -985,14 +996,21 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
latents_mean = self.vae.latents_mean
latents_std = self.vae.latents_std
z_dim = self.vae.config.latent_channels
elif isinstance(self.vae, AutoencoderKLWan):
latents_mean = torch.tensor(self.vae.config.latents_mean)
latents_std = torch.tensor(self.vae.config.latents_std)
z_dim = self.vae.config.z_dim
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLWan,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
@@ -213,8 +213,19 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
if getattr(self, "vae", None):
if isinstance(self.vae, AutoencoderKLLTX2Video):
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -687,14 +698,18 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, -1, 1, 1, 1)
.to(image_latents.device, image_latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
image_latents.device, image_latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
_latents_mean = self.vae.latents_mean
_latents_std = self.vae.latents_std
elif isinstance(self.vae, AutoencoderKLWan):
_latents_mean = torch.tensor(self.vae.config.latents_mean)
_latents_std = torch.tensor(self.vae.config.latents_std)
else:
_latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
_latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_mean) * latents_std
latents[:, :, 0:1] = image_latents.to(dtype)
@@ -1034,14 +1049,21 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
latents_mean = self.vae.latents_mean
latents_std = self.vae.latents_std
z_dim = self.vae.config.latent_channels
elif isinstance(self.vae, AutoencoderKLWan):
latents_mean = torch.tensor(self.vae.config.latents_mean)
latents_std = torch.tensor(self.vae.config.latents_std)
z_dim = self.vae.config.z_dim
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -19,11 +19,16 @@ from __future__ import annotations
import functools
import os
from typing import Callable, ParamSpec, TypeVar
from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
T = TypeVar("T")
P = ParamSpec("P")
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -333,5 +338,21 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
def outer_wrapper(fn: Callable[P, T]):
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
return fn(*args, **kwargs)
return cached(*args, **kwargs)
return inner_wrapper
return outer_wrapper
if is_torch_available():
torch_device = get_device()

View File

@@ -139,9 +139,9 @@ class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
# Override to set a more lenient max diff threshold.
@unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case")
def test_save_load_float16(self):
super().test_save_load_float16(expected_max_diff=0.03)
pass
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):