mirror of https://github.com/huggingface/diffusers.git
synced 2026-03-21 07:57:54 +08:00

Compare commits: type-hint-...export-saf

5 Commits:

| SHA1 |
|---|
| bc8817604a |
| 67613369bb |
| 0c01a4b5e2 |
| 8e4b5607ed |
| c6f72ad2f6 |
SANA-Video conversion script:

```diff
@@ -12,6 +12,7 @@ from termcolor import colored
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from diffusers import (
+    AutoencoderKLLTX2Video,
     AutoencoderKLWan,
     DPMSolverMultistepScheduler,
     FlowMatchEulerDiscreteScheduler,
@@ -24,7 +25,10 @@ from diffusers.utils.import_utils import is_accelerate_available
 
 CTX = init_empty_weights if is_accelerate_available else nullcontext
 
-ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
+ckpt_ids = [
+    "Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
+    "Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
+]
 # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
```
```diff
@@ -92,12 +96,22 @@ def main(args):
     if args.video_size == 480:
         sample_size = 30  # Wan-VAE: 8xp2 downsample factor
+        patch_size = (1, 2, 2)
+        in_channels = 16
+        out_channels = 16
     elif args.video_size == 720:
-        sample_size = 22  # Wan-VAE: 32xp1 downsample factor
+        sample_size = 22  # DC-AE-V: 32xp1 downsample factor
+        patch_size = (1, 1, 1)
+        in_channels = 32
+        out_channels = 32
     else:
         raise ValueError(f"Video size {args.video_size} is not supported.")
 
+    if args.vae_type == "ltx2":
+        sample_size = 22
+        patch_size = (1, 1, 1)
+        in_channels = 128
+        out_channels = 128
+
     for depth in range(layer_num):
         # Transformer blocks.
         converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
```
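For orientation, the branch above appears to tie `sample_size` to the pixel resolution divided by the total spatial downsampling (VAE downsample factor times patchify factor), matching the inline comments. A minimal illustrative check; the helper name is hypothetical and not part of the script:

```python
# Hypothetical helper mirroring the geometry of the branch above.
def latent_sample_size(video_size: int, vae_downsample: int, patch: int) -> int:
    # Latent grid size = pixel size // (VAE downsample factor * patchify factor).
    return video_size // (vae_downsample * patch)

assert latent_sample_size(480, 8, 2) == 30   # Wan-VAE: 8x downsample, 2x2 patchify
assert latent_sample_size(720, 32, 1) == 22  # DC-AE-V / LTX2 VAE: 32x downsample, 1x1 patchify
```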
```diff
@@ -182,8 +196,8 @@ def main(args):
     # Transformer
     with CTX():
         transformer_kwargs = {
-            "in_channels": 16,
-            "out_channels": 16,
+            "in_channels": in_channels,
+            "out_channels": out_channels,
             "num_attention_heads": 20,
             "attention_head_dim": 112,
             "num_layers": 20,
@@ -235,9 +249,12 @@ def main(args):
     else:
         print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
         # VAE
-        vae = AutoencoderKLWan.from_pretrained(
-            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
-        )
+        if args.vae_type == "ltx2":
+            vae_path = args.vae_path or "Lightricks/LTX-2"
+            vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
+        else:
+            vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+            vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
 
         # Text Encoder
         text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
```
```diff
@@ -314,7 +331,23 @@ if __name__ == "__main__":
         choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
         help="Scheduler type to use.",
     )
-    parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
+    parser.add_argument(
+        "--vae_type",
+        default="wan",
+        type=str,
+        choices=["wan", "ltx2"],
+        help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
+    )
+    parser.add_argument(
+        "--vae_path",
+        default=None,
+        type=str,
+        required=False,
+        help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
+    )
+    parser.add_argument(
+        "--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
    parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
```
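The defaulting behavior of the two new flags, as a stand-alone toy sketch (the parser below is illustrative; only the argument names, defaults, and repo ids come from the diff):

```python
import argparse

# Toy parser reproducing just the new flags from the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument("--vae_type", default="wan", type=str, choices=["wan", "ltx2"])
parser.add_argument("--vae_path", default=None, type=str, required=False)
args = parser.parse_args(["--vae_type", "ltx2"])

# Per-type default repo id when --vae_path is not given, as in the script.
if args.vae_type == "ltx2":
    vae_path = args.vae_path or "Lightricks/LTX-2"
else:
    vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
print(vae_path)  # -> Lightricks/LTX-2
```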
Context-parallel hooks module:

```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
-import functools
 import inspect
 from dataclasses import dataclass
 from typing import Type
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
     gather_size_by_comm,
 )
 from ..utils import get_logger
-from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
+from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
 from .hooks import HookRegistry, ModelHook
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
         return tensor
 
 
-@functools.lru_cache(maxsize=64)
+@lru_cache_unless_export(maxsize=64)
 def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
     gather_shapes = []
     for i in range(world_size):
```
Attention dispatch module:

```diff
@@ -49,7 +49,7 @@ from ..utils (
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
-from ..utils.torch_utils import maybe_allow_in_graph
+from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
 from ._modeling_parallel import gather_size_by_comm
@@ -575,7 +575,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
     )
 
 
-@functools.lru_cache(maxsize=128)
+@lru_cache_unless_export(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
     seq_len_q: int,
```
Helios rotary position embedding:

```diff
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import math
-from functools import lru_cache
 from typing import Any
 
 import torch
@@ -343,7 +342,6 @@ class HeliosRotaryPosEmbed(nn.Module):
         return freqs.cos(), freqs.sin()
 
     @torch.no_grad()
-    @lru_cache(maxsize=32)
     def _get_spatial_meshgrid(self, height, width, device_str):
         device = torch.device(device_str)
         grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)
```
Qwen RoPE embedding module:

```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
 import math
 from math import prod
 from typing import Any
@@ -25,7 +24,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import apply_lora_scale, deprecate, logging
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
 
         return vid_freqs, txt_freqs
 
-    @functools.lru_cache(maxsize=128)
+    @lru_cache_unless_export(maxsize=128)
     def _compute_video_freqs(
         self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
     ) -> torch.Tensor:
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
 
         return vid_freqs, txt_freqs
 
-    @functools.lru_cache(maxsize=None)
+    @lru_cache_unless_export(maxsize=None)
     def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
         seq_lens = frame * height * width
         pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
         freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
         return freqs.clone().contiguous()
 
-    @functools.lru_cache(maxsize=None)
+    @lru_cache_unless_export(maxsize=None)
     def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
         seq_lens = frame * height * width
         pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
```
LDMBert model:

```diff
@@ -720,6 +720,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
+        self.post_init()
 
     def forward(
         self,
```
PaintByExample image encoder:

```diff
@@ -35,6 +35,8 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
         # uncondition for scaling
         self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
 
+        self.post_init()
+
     def forward(self, pixel_values, return_uncond_vector=False):
         clip_output = self.model(pixel_values=pixel_values)
         latent_states = clip_output.pooler_output
```
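(`post_init()` is the standard end-of-`__init__` hook on transformers `PreTrainedModel` subclasses; it runs weight initialization and final setup, so both patches call it only after every parameter, such as the new `uncond_vector`, has been registered.)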
SanaVideoPipeline (text-to-video):

```diff
@@ -24,7 +24,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import SanaLoraLoaderMixin
-from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             The tokenizer used to tokenize the prompt.
         text_encoder ([`Gemma2PreTrainedModel`]):
             Text encoder model to encode the input prompts.
-        vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+        vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
         transformer ([`SanaVideoTransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
@@ -213,7 +213,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         self,
         tokenizer: GemmaTokenizer | GemmaTokenizerFast,
         text_encoder: Gemma2PreTrainedModel,
-        vae: AutoencoderDC | AutoencoderKLWan,
+        vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
         transformer: SanaVideoTransformer3DModel,
         scheduler: DPMSolverMultistepScheduler,
     ):
@@ -223,8 +223,19 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+        if getattr(self, "vae", None):
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
+                self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
+            elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
+                self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
+                self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
+            else:
+                self.vae_scale_factor_temporal = 4
+                self.vae_scale_factor_spatial = 8
+        else:
+            self.vae_scale_factor_temporal = 4
+            self.vae_scale_factor_spatial = 8
 
         self.vae_scale_factor = self.vae_scale_factor_spatial
```
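The same dispatch, condensed into a stand-alone sketch. The helper and its duck-typed `hasattr` checks are hypothetical; only the config attribute names and defaults come from the diff:

```python
# Hypothetical helper equivalent to the isinstance dispatch above.
def vae_scale_factors(vae) -> tuple[int, int]:
    """Return (temporal, spatial) scale factors, defaulting to (4, 8)."""
    if vae is None:
        return 4, 8
    cfg = vae.config
    if hasattr(cfg, "temporal_compression_ratio"):  # AutoencoderKLLTX2Video-style config
        return cfg.temporal_compression_ratio, cfg.spatial_compression_ratio
    if hasattr(cfg, "scale_factor_temporal"):  # AutoencoderDC / AutoencoderKLWan-style config
        return cfg.scale_factor_temporal, cfg.scale_factor_spatial
    return 4, 8
```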
```diff
@@ -985,14 +996,21 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 if is_torch_version(">=", "2.5.0")
                 else torch_accelerator_module.OutOfMemoryError
             )
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, self.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-                latents.device, latents.dtype
-            )
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                latents_mean = self.vae.latents_mean
+                latents_std = self.vae.latents_std
+                z_dim = self.vae.config.latent_channels
+            elif isinstance(self.vae, AutoencoderKLWan):
+                latents_mean = torch.tensor(self.vae.config.latents_mean)
+                latents_std = torch.tensor(self.vae.config.latents_std)
+                z_dim = self.vae.config.z_dim
+            else:
+                latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                z_dim = latents.shape[1]
+
+            latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+            latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
             latents = latents / latents_std + latents_mean
             try:
                 video = self.vae.decode(latents, return_dict=False)[0]
```
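Note the inversion: `latents_std` is replaced by its reciprocal before use, so the decode-side `latents / latents_std + latents_mean` undoes the encode-side `(x - mean) / std`. A toy round-trip check with illustrative shapes and values:

```python
import torch

z_dim = 4
x = torch.randn(1, z_dim, 2, 8, 8)  # (B, C, T, H, W) latents
mean = torch.randn(z_dim).view(1, z_dim, 1, 1, 1)
std = (torch.rand(z_dim) + 0.5).view(1, z_dim, 1, 1, 1)

normalized = (x - mean) / std            # encode-side normalization (see the i2v hunk below)
inv_std = 1.0 / std                      # as computed in the diff
recovered = normalized / inv_std + mean  # decode-side, as written above
assert torch.allclose(recovered, x, atol=1e-5)
```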
SanaImageToVideoPipeline:

```diff
@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...loaders import SanaLoraLoaderMixin
-from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             The tokenizer used to tokenize the prompt.
         text_encoder ([`Gemma2PreTrainedModel`]):
             Text encoder model to encode the input prompts.
-        vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+        vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
         transformer ([`SanaVideoTransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         self,
         tokenizer: GemmaTokenizer | GemmaTokenizerFast,
         text_encoder: Gemma2PreTrainedModel,
-        vae: AutoencoderDC | AutoencoderKLWan,
+        vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
         transformer: SanaVideoTransformer3DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
     ):
@@ -213,8 +213,19 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+        if getattr(self, "vae", None):
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
+                self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
+            elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
+                self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
+                self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
+            else:
+                self.vae_scale_factor_temporal = 4
+                self.vae_scale_factor_spatial = 8
+        else:
+            self.vae_scale_factor_temporal = 4
+            self.vae_scale_factor_spatial = 8
 
         self.vae_scale_factor = self.vae_scale_factor_spatial
```
```diff
@@ -687,14 +698,18 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
         image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
 
-        latents_mean = (
-            torch.tensor(self.vae.config.latents_mean)
-            .view(1, -1, 1, 1, 1)
-            .to(image_latents.device, image_latents.dtype)
-        )
-        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
-            image_latents.device, image_latents.dtype
-        )
+        if isinstance(self.vae, AutoencoderKLLTX2Video):
+            _latents_mean = self.vae.latents_mean
+            _latents_std = self.vae.latents_std
+        elif isinstance(self.vae, AutoencoderKLWan):
+            _latents_mean = torch.tensor(self.vae.config.latents_mean)
+            _latents_std = torch.tensor(self.vae.config.latents_std)
+        else:
+            _latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
+            _latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
+
+        latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
+        latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
         image_latents = (image_latents - latents_mean) * latents_std
 
         latents[:, :, 0:1] = image_latents.to(dtype)
```
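After normalization, the encoded image latents are written into the first temporal slice of the video latents (`latents[:, :, 0:1]`), which is how this image-to-video variant injects the conditioning frame.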
```diff
@@ -1034,14 +1049,21 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 if is_torch_version(">=", "2.5.0")
                 else torch_accelerator_module.OutOfMemoryError
             )
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, self.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-                latents.device, latents.dtype
-            )
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                latents_mean = self.vae.latents_mean
+                latents_std = self.vae.latents_std
+                z_dim = self.vae.config.latent_channels
+            elif isinstance(self.vae, AutoencoderKLWan):
+                latents_mean = torch.tensor(self.vae.config.latents_mean)
+                latents_std = torch.tensor(self.vae.config.latents_std)
+                z_dim = self.vae.config.z_dim
+            else:
+                latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                z_dim = latents.shape[1]
+
+            latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+            latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
             latents = latents / latents_std + latents_mean
             try:
                 video = self.vae.decode(latents, return_dict=False)[0]
```
torch_utils:

```diff
@@ -19,11 +19,16 @@ from __future__ import annotations
 
 import functools
 import os
+from typing import Callable, ParamSpec, TypeVar
 
 from . import logging
 from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
 
 
+T = TypeVar("T")
+P = ParamSpec("P")
+
+
 if is_torch_available():
     import torch
     from torch.fft import fftn, fftshift, ifftn, ifftshift
```
```diff
@@ -333,5 +338,21 @@ def disable_full_determinism():
     torch.use_deterministic_algorithms(False)
 
 
+@functools.wraps(functools.lru_cache)
+def lru_cache_unless_export(maxsize=128, typed=False):
+    def outer_wrapper(fn: Callable[P, T]):
+        cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
+
+        @functools.wraps(fn)
+        def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
+            if torch.compiler.is_exporting():
+                return fn(*args, **kwargs)
+            return cached(*args, **kwargs)
+
+        return inner_wrapper
+
+    return outer_wrapper
+
+
 if is_torch_available():
     torch_device = get_device()
```
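In eager mode the decorator behaves like a plain `functools.lru_cache`; when `torch.compiler.is_exporting()` is true (i.e. during `torch.export` tracing), it bypasses the cache so memoized results are not baked into the exported graph. A minimal usage sketch, assuming this branch of diffusers is installed; the decorated function itself is illustrative:

```python
import torch
from diffusers.utils.torch_utils import lru_cache_unless_export

@lru_cache_unless_export(maxsize=32)
def rope_grid(height: int, width: int) -> torch.Tensor:
    # Stand-in for an expensive, shape-dependent computation.
    return torch.outer(torch.arange(height).float(), torch.arange(width).float())

a = rope_grid(8, 8)  # computed once...
b = rope_grid(8, 8)  # ...served from the cache in eager mode
assert a is b        # lru_cache returns the identical object
# Under torch.export, the undecorated function runs instead,
# keeping the trace free of cached tensors.
```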
Helios pipeline fast tests:

```diff
@@ -139,9 +139,9 @@ class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
-    # Override to set a more lenient max diff threshold.
+    @unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case")
     def test_save_load_float16(self):
-        super().test_save_load_float16(expected_max_diff=0.03)
+        pass
 
     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
```