Compare commits


11 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Daniel Gu | f4d47b9cec | Infer latent dims if latents/audio_latents is supplied | 2026-01-14 03:09:53 +01:00 |
| Sayak Paul | ce5a51430b | Merge branch 'main' into feat/distill-ltx2 | 2026-01-13 11:04:51 +05:30 |
| Vinh H. Pham | 9575e0632a | Apply suggestions from code review (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>) | 2026-01-13 12:01:03 +07:00 |
| Pham Hong Vinh | faeccc557a | Merge branch 'feat/distill-ltx2' of github.com:rootonchair/diffusers into feat/distill-ltx2 | 2026-01-12 11:31:16 +07:00 |
| Pham Hong Vinh | 96fbcd8301 | fix copies and add pack latents to i2v | 2026-01-12 11:30:31 +07:00 |
| Pham Hong Vinh | 837fd85c76 | add sigma param to ltx2 i2v | 2026-01-12 11:26:31 +07:00 |
| Sayak Paul | d988fc34f1 | Merge branch 'main' into feat/distill-ltx2 | 2026-01-12 09:16:50 +05:30 |
| Pham Hong Vinh | 82c2e7f068 | remove prenorm | 2026-01-11 23:01:34 +07:00 |
| Pham Hong Vinh | 6fbeacf53b | make style & quality | 2026-01-11 23:00:02 +07:00 |
| Pham Hong Vinh | 9c754a46aa | add time conditioning conversion and token packing for latents | 2026-01-11 22:13:58 +07:00 |
| Pham Hong Vinh | 3d78f9d17d | add constants for distill sigmas values and allow ltx pipeline to pass in sigmas | 2026-01-09 19:45:20 +07:00 |
4 changed files with 110 additions and 51 deletions

View File

@@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
@@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
return connectors
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def get_ltx2_video_vae_config(
version: str, timestep_conditioning: bool = False
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
@@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"timestep_conditioning": timestep_conditioning,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
@@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"timestep_conditioning": timestep_conditioning,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
@@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
def convert_ltx2_video_vae(
original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool
) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
@@ -717,6 +723,9 @@ def get_args():
help="Latent upsampler filename",
)
parser.add_argument(
"--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
@@ -786,7 +795,9 @@ def main(args):
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
vae = convert_ltx2_video_vae(
original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning
)
if not args.full_pipeline and not args.upsample_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
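
For context, the new `--timestep_conditioning` flag simply threads through to `get_ltx2_video_vae_config` and flips the `timestep_conditioning` entry of the VAE config. A minimal sketch of that flow, meant to be run inside the conversion script's namespace (the checkpoint filename and version string below are assumptions for illustration):

```python
import os
import torch

# Hypothetical inputs; only the keyword flow mirrors the diff above.
original_vae_ckpt = load_hub_or_local_checkpoint(filename="ltx2_vae.safetensors")
vae = convert_ltx2_video_vae(
    original_vae_ckpt,
    version="2.0",               # assumed version string
    timestep_conditioning=True,  # equivalent to passing --timestep_conditioning on the CLI
)
vae.to(torch.bfloat16).save_pretrained(os.path.join("output", "vae"))
```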

View File

@@ -653,6 +653,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
if latents.ndim == 5:
# latents are of shape [B, C, F, H, W], need to be packed
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
return latents.to(device=device, dtype=dtype)
height = height // self.vae_spatial_compression_ratio
@@ -677,29 +682,23 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
self,
batch_size: int = 1,
num_channels_latents: int = 8,
audio_latent_length: int = 1, # 1 is just a dummy value
num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = round(duration_s * latents_per_second)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_length
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
return latents.to(device=device, dtype=dtype)
# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -709,7 +708,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
return latents, latent_length
return latents
@property
def guidance_scale(self):
@@ -750,6 +749,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
num_frames: int = 121,
frame_rate: float = 24.0,
num_inference_steps: int = 40,
sigmas: Optional[List[float]] = None,
timesteps: List[int] = None,
guidance_scale: float = 4.0,
guidance_rescale: float = 0.0,
@@ -788,6 +788,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
num_inference_steps (`int`, *optional*, defaults to 40):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
@@ -922,6 +926,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
if latents is not None:
if latents.ndim == 5:
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
else:
logger.warning(
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
)
video_sequence_length = latent_num_frames * latent_height * latent_width
num_channels_latents = self.transformer.config.in_channels
@@ -937,20 +949,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents,
)
duration_s = num_frames / frame_rate
audio_latents_per_second = (
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
)
audio_num_frames = round(duration_s * audio_latents_per_second)
if audio_latents is not None:
if audio_latents.ndim == 4:
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
else:
logger.warning(
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
f" cannot be inferred. Make sure the supplied `num_frames` is correct."
)
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
audio_latents, audio_num_frames = self.prepare_audio_latents(
audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
sampling_rate=self.audio_sampling_rate,
hop_length=self.audio_hop_length,
dtype=torch.float32,
device=device,
generator=generator,
@@ -958,7 +980,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("base_image_seq_len", 1024),
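
With the new `sigmas` argument, the distilled schedule can be supplied directly instead of the default `np.linspace` ramp. A minimal usage sketch, assuming `LTX2Pipeline` is exported from the top-level `diffusers` namespace on this branch and reusing the dummy checkpoint id from the conversion script above (a real distilled checkpoint would go there; `guidance_scale=1.0` is likewise an assumption):

```python
import torch
from diffusers import LTX2Pipeline

# "diffusers-internal-dev/dummy-ltx2" is the test model_id from the conversion script above.
pipe = LTX2Pipeline.from_pretrained("diffusers-internal-dev/dummy-ltx2", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Values mirror DISTILLED_SIGMA_VALUES from the new constants file below.
distilled_sigmas = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]

output = pipe(
    prompt="a red panda eating bamboo in the rain",
    num_inference_steps=len(distilled_sigmas),
    sigmas=distilled_sigmas,   # overrides the default np.linspace schedule
    guidance_scale=1.0,        # assumption: distilled checkpoints are typically run without CFG
)
```

The same call also accepts unpacked `latents` of shape `[B, C, F, H, W]` (and `audio_latents` of shape `[B, C, L, M]`): they are packed internally, and their shapes are used to infer `latent_num_frames`, `latent_height`, and `latent_width` instead of relying on `height`, `width`, and `num_frames`.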

View File

@@ -689,6 +689,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)
if latents.ndim == 5:
# latents are of shape [B, C, F, H, W], need to be packed
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
@@ -737,29 +742,23 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
self,
batch_size: int = 1,
num_channels_latents: int = 8,
audio_latent_length: int = 1, # 1 is just a dummy value
num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = round(duration_s * latents_per_second)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_length
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
return latents.to(device=device, dtype=dtype)
# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -769,7 +768,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
return latents, latent_length
return latents
@property
def guidance_scale(self):
@@ -811,6 +810,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
num_frames: int = 121,
frame_rate: float = 24.0,
num_inference_steps: int = 40,
sigmas: Optional[List[float]] = None,
timesteps: List[int] = None,
guidance_scale: float = 4.0,
guidance_rescale: float = 0.0,
@@ -851,6 +851,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
num_inference_steps (`int`, *optional*, defaults to 40):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
@@ -982,6 +986,19 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
# 4. Prepare latent variables
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
if latents is not None:
if latents.ndim == 5:
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
else:
logger.warning(
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
)
video_sequence_length = latent_num_frames * latent_height * latent_width
if latents is None:
image = self.video_processor.preprocess(image, height=height, width=width)
image = image.to(device=device, dtype=prompt_embeds.dtype)
@@ -1002,20 +1019,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
if self.do_classifier_free_guidance:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
duration_s = num_frames / frame_rate
audio_latents_per_second = (
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
)
audio_num_frames = round(duration_s * audio_latents_per_second)
if audio_latents is not None:
if audio_latents.ndim == 4:
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
else:
logger.warning(
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
f" cannot be inferred. Make sure the supplied `num_frames` is correct."
)
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
audio_latents, audio_num_frames = self.prepare_audio_latents(
audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
sampling_rate=self.audio_sampling_rate,
hop_length=self.audio_hop_length,
dtype=torch.float32,
device=device,
generator=generator,
@@ -1023,12 +1050,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("base_image_seq_len", 1024),
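
Both pipelines now compute the audio latent length at the call site and pass it to `prepare_audio_latents` as `audio_latent_length`, rather than deriving it inside that method. A worked example of the arithmetic (sampling rate and hop length come from the removed defaults above; the temporal compression ratio of 4 is an assumption for illustration):

```python
# Defaults from the old prepare_audio_latents signature: sampling_rate=16000, hop_length=160.
num_frames, frame_rate = 121, 24.0
sampling_rate, hop_length = 16000, 160
audio_vae_temporal_compression_ratio = 4  # assumed value

duration_s = num_frames / frame_rate  # ~5.04 s of video
audio_latents_per_second = sampling_rate / hop_length / float(audio_vae_temporal_compression_ratio)  # 25.0
audio_num_frames = round(duration_s * audio_latents_per_second)  # 126

print(audio_num_frames)  # passed to prepare_audio_latents as audio_latent_length
```

If unpacked `audio_latents` of shape `[B, C, L, M]` are supplied, this value is instead read from the tensor's third dimension; for already-packed latents it cannot be recovered from the shape, hence the warning in the diff above.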

View File

@@ -0,0 +1,4 @@
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
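
As the comment notes, the stage-2 schedule is a tail subset of the full distilled schedule; a standalone check with the values copied from above:

```python
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]

# The stage-2 values are exactly the last three entries of the full schedule.
assert STAGE_2_DISTILLED_SIGMA_VALUES == DISTILLED_SIGMA_VALUES[-3:]
```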