Merge branch 'main' into wan-layerwise-upcasting-tests

Sayak Paul
2026-01-29 11:59:57 +05:30
committed by GitHub
15 changed files with 620 additions and 109 deletions

View File

@@ -24,6 +24,179 @@ You can find all the original LTX-Video checkpoints under the [Lightricks](https
The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
## Two-Stage Generation
This is the recommended pipeline for production-quality generation. It is composed of two stages:
- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning.
- Stage 2: Upsample the Stage 1 output by 2x and refine details using a distilled LoRA model to improve fidelity and visual quality. Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness.
Sample usage of the two-stage text-to-video pipeline:
```py
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video
device = "cuda:0"
width = 768
height = 512
pipe = LTX2Pipeline.from_pretrained(
"Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)
prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
# Stage 1 default (non-distilled) inference
frame_rate = 24.0
video_latent, audio_latent = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=40,
sigmas=None,
guidance_scale=4.0,
output_type="latent",
return_dict=False,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
"Lightricks/LTX-2",
subfolder="latent_upsampler",
torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
latents=video_latent,
output_type="latent",
return_dict=False,
)[0]
# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
"Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling is usually necessary to avoid OOM errors during VAE decoding
pipe.vae.enable_tiling()
# Change scheduler to use Stage 2 distilled sigmas as is
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
# Stage 2 inference with distilled LoRA and sigmas
video, audio = pipe(
latents=upscaled_video_latent,
audio_latents=audio_latent,
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=3,
noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218
sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
guidance_scale=1.0,
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_lora_distilled_sample.mp4",
)
```
## Distilled checkpoint generation
The fastest two-stage generation pipeline, using a distilled checkpoint.
```py
import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video
device = "cuda"
width = 768
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"
pipe = LTX2Pipeline.from_pretrained(
model_path, torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)
prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
frame_rate = 24.0
video_latent, audio_latent = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=8,
sigmas=DISTILLED_SIGMA_VALUES,
guidance_scale=1.0,
generator=generator,
output_type="latent",
return_dict=False,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
model_path,
subfolder="latent_upsampler",
torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
latents=video_latent,
output_type="latent",
return_dict=False,
)[0]
video, audio = pipe(
latents=upscaled_video_latent,
audio_latents=audio_latent,
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=3,
noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178
sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
generator=generator,
guidance_scale=1.0,
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_distilled_sample.mp4",
)
```
## LTX2Pipeline
[[autodoc]] LTX2Pipeline

View File

@@ -83,25 +83,6 @@ Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-
> [!TIP]
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
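A minimal sketch of that recommendation, quantizing only the transformer to FP8 and compiling it (the `float8dq_e4m3` quant type string and the compile settings are assumptions and may differ across torchao/diffusers versions):
```py
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

# FP8 dynamic-activation + FP8 weight quantization of a single model (assumed quant type string)
quantization_config = TorchAoConfig("float8dq_e4m3")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Compile the quantized transformer for additional speedups on compatible GPUs
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

image = pipeline("a photo of a cat", num_inference_steps=4, guidance_scale=0.0).images[0]
```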
## autoquant
torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on the chosen input types and shapes. At the moment, this is only supported in Diffusers for individual models.
```py
import torch
from diffusers import DiffusionPipeline
from torchao.quantization import autoquant
# Load the pipeline
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
transformer = autoquant(pipeline.transformer)
```
## Supported quantization types
torchao supports weight-only quantization as well as combined weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.
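For instance, a weight-only int8 configuration for a single model might look like the following sketch (assuming the `int8wo` quant type string is available in your installed torchao/diffusers versions):
```py
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# int8 weight-only quantization applied to an individual model
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```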

View File

@@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
@@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
return connectors
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def get_ltx2_video_vae_config(
version: str, timestep_conditioning: bool = False
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
@@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"timestep_conditioning": timestep_conditioning,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
@@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"timestep_conditioning": timestep_conditioning,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
@@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
def convert_ltx2_video_vae(
original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool
) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
@@ -659,10 +665,15 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi
def get_args():
parser = argparse.ArgumentParser()
def none_or_str(value: str):
if isinstance(value, str) and value.lower() == "none":
return None
return value
parser.add_argument(
"--original_state_dict_repo_id",
default="Lightricks/LTX-2",
type=str,
type=none_or_str,
help="HF Hub repo id with LTX 2.0 checkpoint",
)
parser.add_argument(
@@ -682,7 +693,7 @@ def get_args():
parser.add_argument(
"--combined_filename",
default="ltx-2-19b-dev.safetensors",
type=str,
type=none_or_str,
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
)
parser.add_argument("--vae_prefix", default="vae.", type=str)
@@ -701,22 +712,25 @@ def get_args():
parser.add_argument(
"--text_encoder_model_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
type=none_or_str,
help="HF Hub id for the LTX 2.0 base text encoder model",
)
parser.add_argument(
"--tokenizer_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
type=none_or_str,
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument(
"--latent_upsampler_filename",
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
type=str,
type=none_or_str,
help="Latent upsampler filename",
)
parser.add_argument(
"--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
@@ -786,7 +800,9 @@ def main(args):
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
vae = convert_ltx2_video_vae(
original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning
)
if not args.full_pipeline and not args.upsample_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))

View File

@@ -743,8 +743,8 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
# Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are computed
# over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict
latents_std = torch.zeros((base_channels,))
latents_mean = torch.ones((base_channels,))
latents_std = torch.ones((base_channels,))
latents_mean = torch.zeros((base_channels,))
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)

View File

@@ -584,6 +584,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
@@ -594,12 +605,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents = latents * latents_std / scaling_factor + latents_mean
return latents
@staticmethod
def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
latents_mean = latents_mean.to(latents.device, latents.dtype)
latents_std = latents_std.to(latents.device, latents.dtype)
return (latents - latents_mean) / latents_std
@staticmethod
def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
latents_mean = latents_mean.to(latents.device, latents.dtype)
latents_std = latents_std.to(latents.device, latents.dtype)
return (latents * latents_std) + latents_mean
@staticmethod
def _create_noised_state(
latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
):
noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
noised_latents = noise_scale * noise + (1 - noise_scale) * latents
return noised_latents
@staticmethod
def _pack_audio_latents(
latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
@@ -647,12 +672,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
height: int = 512,
width: int = 768,
num_frames: int = 121,
noise_scale: float = 0.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
if latents.ndim == 5:
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
# latents are of shape [B, C, F, H, W], need to be packed
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
if latents.ndim != 3:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
)
latents = self._create_noised_state(latents, noise_scale, generator)
return latents.to(device=device, dtype=dtype)
height = height // self.vae_spatial_compression_ratio
@@ -677,29 +716,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
self,
batch_size: int = 1,
num_channels_latents: int = 8,
audio_latent_length: int = 1, # 1 is just a dummy value
num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
noise_scale: float = 0.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = round(duration_s * latents_per_second)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_length
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
if latents.ndim != 3:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
)
latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
latents = self._create_noised_state(latents, noise_scale, generator)
return latents.to(device=device, dtype=dtype)
# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -709,7 +749,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
return latents, latent_length
return latents
@property
def guidance_scale(self):
@@ -750,9 +790,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
num_frames: int = 121,
frame_rate: float = 24.0,
num_inference_steps: int = 40,
sigmas: Optional[List[float]] = None,
timesteps: List[int] = None,
guidance_scale: float = 4.0,
guidance_rescale: float = 0.0,
noise_scale: float = 0.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -788,6 +830,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
num_inference_steps (`int`, *optional*, defaults to 40):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
@@ -804,6 +850,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
[Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
noise_scale (`float`, *optional*, defaults to `0.0`):
The interpolation factor between random noise and the provided latents. The `latents` and
`audio_latents` are re-noised with this factor before denoising continues.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -922,6 +971,21 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
if latents is not None:
if latents.ndim == 5:
logger.info(
"Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
)
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
elif latents.ndim == 3:
logger.warning(
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
)
else:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
)
video_sequence_length = latent_num_frames * latent_height * latent_width
num_channels_latents = self.transformer.config.in_channels
@@ -931,26 +995,45 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
height,
width,
num_frames,
noise_scale,
torch.float32,
device,
generator,
latents,
)
duration_s = num_frames / frame_rate
audio_latents_per_second = (
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
)
audio_num_frames = round(duration_s * audio_latents_per_second)
if audio_latents is not None:
if audio_latents.ndim == 4:
logger.info(
"Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred."
)
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
elif audio_latents.ndim == 3:
logger.warning(
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct."
)
else:
raise ValueError(
f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
)
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
audio_latents, audio_num_frames = self.prepare_audio_latents(
audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
sampling_rate=self.audio_sampling_rate,
hop_length=self.audio_hop_length,
noise_scale=noise_scale,
dtype=torch.float32,
device=device,
generator=generator,
@@ -958,7 +1041,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("base_image_seq_len", 1024),

View File

@@ -614,6 +614,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
latents = latents * latents_std / scaling_factor + latents_mean
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state
def _create_noised_state(
latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
):
noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
noised_latents = noise_scale * noise + (1 - noise_scale) * latents
return noised_latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
def _pack_audio_latents(
@@ -656,6 +665,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents
def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
latents_mean = latents_mean.to(latents.device, latents.dtype)
latents_std = latents_std.to(latents.device, latents.dtype)
return (latents - latents_mean) / latents_std
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents
def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
@@ -671,6 +687,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
height: int = 512,
width: int = 704,
num_frames: int = 161,
noise_scale: float = 0.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
@@ -686,6 +703,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
if latents is not None:
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
if latents.ndim == 5:
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator)
# latents are of shape [B, C, F, H, W], need to be packed
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)
@@ -737,29 +763,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
self,
batch_size: int = 1,
num_channels_latents: int = 8,
audio_latent_length: int = 1, # 1 is just a dummy value
num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
noise_scale: float = 0.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = round(duration_s * latents_per_second)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_length
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
if latents.ndim != 3:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
)
latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
latents = self._create_noised_state(latents, noise_scale, generator)
return latents.to(device=device, dtype=dtype)
# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -769,7 +796,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
return latents, latent_length
return latents
@property
def guidance_scale(self):
@@ -811,9 +838,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
num_frames: int = 121,
frame_rate: float = 24.0,
num_inference_steps: int = 40,
sigmas: Optional[List[float]] = None,
timesteps: List[int] = None,
guidance_scale: float = 4.0,
guidance_rescale: float = 0.0,
noise_scale: float = 0.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -851,6 +880,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
num_inference_steps (`int`, *optional*, defaults to 40):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
@@ -867,6 +900,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
[Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
noise_scale (`float`, *optional*, defaults to `0.0`):
The interpolation factor between random noise and the provided latents. The `latents` and
`audio_latents` are re-noised with this factor before denoising continues.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -982,6 +1018,26 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
# 4. Prepare latent variables
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
if latents is not None:
if latents.ndim == 5:
logger.info(
"Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
)
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
elif latents.ndim == 3:
logger.warning(
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
)
else:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
)
video_sequence_length = latent_num_frames * latent_height * latent_width
if latents is None:
image = self.video_processor.preprocess(image, height=height, width=width)
image = image.to(device=device, dtype=prompt_embeds.dtype)
@@ -994,6 +1050,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
height,
width,
num_frames,
noise_scale,
torch.float32,
device,
generator,
@@ -1002,20 +1059,38 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
if self.do_classifier_free_guidance:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
duration_s = num_frames / frame_rate
audio_latents_per_second = (
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
)
audio_num_frames = round(duration_s * audio_latents_per_second)
if audio_latents is not None:
if audio_latents.ndim == 4:
logger.info(
"Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred."
)
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
elif audio_latents.ndim == 3:
logger.warning(
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct."
)
else:
raise ValueError(
f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
)
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
audio_latents, audio_num_frames = self.prepare_audio_latents(
audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
sampling_rate=self.audio_sampling_rate,
hop_length=self.audio_hop_length,
noise_scale=noise_scale,
dtype=torch.float32,
device=device,
generator=generator,
@@ -1023,12 +1098,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("base_image_seq_len", 1024),

View File

@@ -228,17 +228,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
filtered = latents * scales
return filtered
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
def _denormalize_latents(
@@ -408,9 +397,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
else:
if not self.vae.config.timestep_conditioning:

View File

@@ -0,0 +1,6 @@
# Pre-trained sigma values for the distilled model are taken from
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]

View File

@@ -496,6 +496,17 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
h_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
w_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[2]
calc_height = height // h_multiple_of * h_multiple_of
calc_width = width // w_multiple_of * w_multiple_of
if height != calc_height or width != calc_width:
logger.warning(
f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. "
f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
)
height, width = calc_height, calc_width
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale

View File

@@ -637,6 +637,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
h_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
w_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[2]
calc_height = height // h_multiple_of * h_multiple_of
calc_width = width // w_multiple_of * w_multiple_of
if height != calc_height or width != calc_width:
logger.warning(
f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. "
f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
)
height, width = calc_height, calc_width
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale

View File

@@ -623,7 +623,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"""
if is_torchao_available():
# TODO(aryan): Support autoquant and sparsify
# TODO(aryan): Support sparsify
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,

View File

@@ -344,7 +344,6 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
# For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__

View File

@@ -105,7 +105,6 @@ def rescale_zero_terminal_snr(alphas_cumprod):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
@@ -175,11 +174,14 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Choose from
`leading`, `linspace` or `trailing`.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
snr_shift_scale (`float`, defaults to 3.0):
Shift scale for SNR.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -191,15 +193,15 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.0120,
beta_schedule: str = "scaled_linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "linspace", "trailing"] = "leading",
rescale_betas_zero_snr: bool = False,
snr_shift_scale: float = 3.0,
):
@@ -209,7 +211,15 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float64,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -266,13 +276,20 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None` (the default), the timesteps are not
moved.
"""
if num_inference_steps > self.config.num_train_timesteps:
@@ -311,7 +328,27 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device)
def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
def get_variables(
self,
alpha_prod_t: torch.Tensor,
alpha_prod_t_prev: torch.Tensor,
alpha_prod_t_back: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
"""
Compute the variables used for DPM-Solver++(2M), following the original implementation.
Args:
alpha_prod_t (`torch.Tensor`):
The cumulative product of alphas at the current timestep.
alpha_prod_t_prev (`torch.Tensor`):
The cumulative product of alphas at the previous timestep.
alpha_prod_t_back (`torch.Tensor`, *optional*):
The cumulative product of alphas at the timestep before the previous timestep.
Returns:
`tuple`:
A tuple containing the variables `h`, `r`, `lamb`, `lamb_next`.
"""
lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
h = lamb_next - lamb
@@ -324,7 +361,36 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
else:
return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
def get_mult(
self,
h: torch.Tensor,
r: Optional[torch.Tensor],
alpha_prod_t: torch.Tensor,
alpha_prod_t_prev: torch.Tensor,
alpha_prod_t_back: Optional[torch.Tensor] = None,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Compute the multipliers for the previous sample and the predicted original sample.
Args:
h (`torch.Tensor`):
The log-SNR difference.
r (`torch.Tensor`):
The ratio of log-SNR differences.
alpha_prod_t (`torch.Tensor`):
The cumulative product of alphas at the current timestep.
alpha_prod_t_prev (`torch.Tensor`):
The cumulative product of alphas at the previous timestep.
alpha_prod_t_back (`torch.Tensor`, *optional*):
The cumulative product of alphas at the timestep before the previous timestep.
Returns:
`tuple`:
A tuple containing the multipliers.
"""
mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5
@@ -338,13 +404,13 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
old_pred_original_sample: torch.Tensor,
old_pred_original_sample: Optional[torch.Tensor],
timestep: int,
timestep_back: int,
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = False,
) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -355,8 +421,12 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
old_pred_original_sample (`torch.Tensor`):
The predicted original sample from the previous timestep.
timestep (`int`):
The current discrete timestep in the diffusion chain.
timestep_back (`int`):
The timestep to look back to.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
@@ -436,7 +506,12 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
return prev_sample, pred_original_sample
else:
denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
noise = randn_tensor(
sample.shape,
generator=generator,
device=sample.device,
dtype=sample.dtype,
)
x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise
prev_sample = x_advanced
@@ -524,5 +599,5 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -222,7 +222,57 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
expected_audio_slice = torch.tensor(
[
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
0.0263, 0.0528, 0.1217, 0.1104, 0.1632, 0.1072, 0.1789, 0.0949, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_two_stages_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "latent"
first_stage_output = pipe(**inputs)
video_latent = first_stage_output.frames
audio_latent = first_stage_output.audio
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
inputs["latents"] = video_latent
inputs["audio_latents"] = audio_latent
inputs["output_type"] = "pt"
second_stage_output = pipe(**inputs)
video = second_stage_output.frames
audio = second_stage_output.audio
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.5514, 0.5943, 0.4260, 0.5971, 0.4306, 0.6369, 0.3124, 0.6964, 0.5419, 0.2412, 0.3882, 0.4504, 0.1941, 0.3404, 0.6037, 0.2464
]
)
expected_audio_slice = torch.tensor(
[
0.0252, 0.0526, 0.1211, 0.1119, 0.1638, 0.1042, 0.1776, 0.0948, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on

View File

@@ -224,7 +224,57 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
expected_audio_slice = torch.tensor(
[
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
0.0294, 0.0498, 0.1269, 0.1135, 0.1639, 0.1116, 0.1730, 0.0931, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_two_stages_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["output_type"] = "latent"
first_stage_output = pipe(**inputs)
video_latent = first_stage_output.frames
audio_latent = first_stage_output.audio
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
inputs["latents"] = video_latent
inputs["audio_latents"] = audio_latent
inputs["output_type"] = "pt"
second_stage_output = pipe(**inputs)
video = second_stage_output.frames
audio = second_stage_output.audio
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.2665, 0.6915, 0.2939, 0.6767, 0.2552, 0.6215, 0.1765, 0.6248, 0.2800, 0.2356, 0.3480, 0.5395, 0.3190, 0.4128, 0.4784, 0.4086
]
)
expected_audio_slice = torch.tensor(
[
0.0273, 0.0490, 0.1253, 0.1129, 0.1655, 0.1057, 0.1707, 0.0943, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on