From 74a0f0b694e1d9b3e4179a1b9e856cc49a888436 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Wed, 28 Jan 2026 10:40:22 -0500 Subject: [PATCH 1/4] remove torchao autoquant from diffusers docs (#13048) Summary: Context: https://github.com/pytorch/ao/issues/3739 Test Plan: CI, since this does not change any Python code --- docs/source/en/quantization/torchao.md | 19 ------------------- .../quantizers/quantization_config.py | 2 +- .../quantizers/torchao/torchao_quantizer.py | 1 - 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index a501c4efc4..aa415dbc36 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -83,25 +83,6 @@ Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue- > [!TIP] > The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible. -## autoquant - -torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment. - -```py -import torch -from diffusers import DiffusionPipeline -from torchao.quantization import autoquant - -# Load the pipeline -pipeline = DiffusionPipeline.from_pretrained( - "black-forest-labs/FLUX.1-schnell", - torch_dtype=torch.bfloat16, - device_map="cuda" -) - -transformer = autoquant(pipeline.transformer) -``` - ## Supported quantization types torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7. 
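For reference, the quantization path that remains documented — passing a `TorchAoConfig` to a model-level `from_pretrained` and optionally compiling — can be sketched as follows. This is a minimal illustration, not part of the patch: the `"int8wo"` quant type and the FLUX.1-schnell checkpoint are assumed as examples, and an FP8 type from the table referenced above can be substituted on GPUs with compute capability 8.9 or higher.

```py
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

# Quantize only the transformer weights; "int8wo" (int8 weight-only) is used as a
# broadly supported example -- swap in an FP8 type from the supported-types table
# on GPUs with compute capability >= 8.9.
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# torch.compile typically pairs well with torchao-quantized weights
pipeline.transformer = torch.compile(pipeline.transformer, fullgraph=True)

image = pipeline(
    "A cat holding a sign that says hello world",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("output.png")
```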
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index cc5b5cc93d..c905a928c7 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -623,7 +623,7 @@ class TorchAoConfig(QuantizationConfigMixin): """ if is_torchao_available(): - # TODO(aryan): Support autoquant and sparsify + # TODO(aryan): Support sparsify from torchao.quantization import ( float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 0405afdaae..11435b85eb 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -344,7 +344,6 @@ class TorchAoHfQuantizer(DiffusersQuantizer): from torchao.core.config import AOBaseConfig quant_type = self.quantization_config.quant_type - # For autoquant case, it will be treated in the string implementation below in map_to_target_dtype if isinstance(quant_type, AOBaseConfig): # Extract size digit using fuzzy match on the class name config_name = quant_type.__class__.__name__ From 0ab2124958af12b5cf3f43e61833f52e88bd532e Mon Sep 17 00:00:00 2001 From: David El Malih Date: Wed, 28 Jan 2026 19:40:08 +0100 Subject: [PATCH 2/4] docs: improve docstring scheduling_dpm_cogvideox.py (#13044) --- .../schedulers/scheduling_dpm_cogvideox.py | 103 +++++++++++++++--- 1 file changed, 89 insertions(+), 14 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py index 3e50ebbfe0..0c576467a1 100644 --- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -105,7 +105,6 @@ def rescale_zero_terminal_snr(alphas_cumprod): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. @@ -175,11 +174,14 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Choose from + `leading`, `linspace` or `trailing`. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + snr_shift_scale (`float`, defaults to 3.0): + Shift scale for SNR. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -191,15 +193,15 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.0120, - beta_schedule: str = "scaled_linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "linspace", "trailing"] = "leading", rescale_betas_zero_snr: bool = False, snr_shift_scale: float = 3.0, ): @@ -209,7 +211,15 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float64, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -266,13 +276,20 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): """ return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Optional[Union[str, torch.device]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None` (the default), the timesteps are not + moved. """ if num_inference_steps > self.config.num_train_timesteps: @@ -311,7 +328,27 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): self.timesteps = torch.from_numpy(timesteps).to(device) - def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None): + def get_variables( + self, + alpha_prod_t: torch.Tensor, + alpha_prod_t_prev: torch.Tensor, + alpha_prod_t_back: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]: + """ + Compute the variables used for DPM-Solver++ (2M) referencing the original implementation. + + Args: + alpha_prod_t (`torch.Tensor`): + The cumulative product of alphas at the current timestep. + alpha_prod_t_prev (`torch.Tensor`): + The cumulative product of alphas at the previous timestep. + alpha_prod_t_back (`torch.Tensor`, *optional*): + The cumulative product of alphas at the timestep before the previous timestep. + + Returns: + `tuple`: + A tuple containing the variables `h`, `r`, `lamb`, `lamb_next`. 
+ """ lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log() lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log() h = lamb_next - lamb @@ -324,7 +361,36 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): else: return h, None, lamb, lamb_next - def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back): + def get_mult( + self, + h: torch.Tensor, + r: Optional[torch.Tensor], + alpha_prod_t: torch.Tensor, + alpha_prod_t_prev: torch.Tensor, + alpha_prod_t_back: Optional[torch.Tensor] = None, + ) -> Union[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ]: + """ + Compute the multipliers for the previous sample and the predicted original sample. + + Args: + h (`torch.Tensor`): + The log-SNR difference. + r (`torch.Tensor`): + The ratio of log-SNR differences. + alpha_prod_t (`torch.Tensor`): + The cumulative product of alphas at the current timestep. + alpha_prod_t_prev (`torch.Tensor`): + The cumulative product of alphas at the previous timestep. + alpha_prod_t_back (`torch.Tensor`, *optional*): + The cumulative product of alphas at the timestep before the previous timestep. + + Returns: + `tuple`: + A tuple containing the multipliers. + """ mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp() mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5 @@ -338,13 +404,13 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - old_pred_original_sample: torch.Tensor, + old_pred_original_sample: Optional[torch.Tensor], timestep: int, timestep_back: int, sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = False, ) -> Union[DDIMSchedulerOutput, Tuple]: @@ -355,8 +421,12 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + old_pred_original_sample (`torch.Tensor`): + The predicted original sample from the previous timestep. + timestep (`int`): The current discrete timestep in the diffusion chain. + timestep_back (`int`): + The timestep to look back to. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. 
eta (`float`): @@ -436,7 +506,12 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): return prev_sample, pred_original_sample else: denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample - noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) + noise = randn_tensor( + sample.shape, + generator=generator, + device=sample.device, + dtype=sample.dtype, + ) x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise prev_sample = x_advanced @@ -524,5 +599,5 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps From a58d0b9bec320722aaf193f269bf5c2ffd51d0d4 Mon Sep 17 00:00:00 2001 From: Jayce <93087577+Jayce-Ping@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:06:58 +0800 Subject: [PATCH 3/4] Fix Wan/WanI2V patchification (#13038) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix Wan/WanI2V patchification * Apply style fixes * Apply suggestions from code review I agree with you for the idea of using `patch_size` instead. Thanks!😊 Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Fix logger warning --------- Co-authored-by: github-actions[bot] Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/wan/pipeline_wan.py | 11 +++++++++++ src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index dc2bb47110..362440c8c9 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -496,6 +496,17 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + h_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. " + f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." 
+ ) + height, width = calc_height, calc_width + if self.config.boundary_ratio is not None and guidance_scale_2 is None: guidance_scale_2 = guidance_scale diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index b7fd0b0598..e22d8393bd 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -637,6 +637,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + h_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. " + f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." + ) + height, width = calc_height, calc_width + if self.config.boundary_ratio is not None and guidance_scale_2 is None: guidance_scale_2 = guidance_scale From a2ea45a5da7f9d4b197ee5da7481768bf865f7d1 Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Thu, 29 Jan 2026 10:35:43 +0700 Subject: [PATCH 4/4] LTX2 distilled checkpoint support (#12934) * add constants for distill sigmas values and allow ltx pipeline to pass in sigmas * add time conditioning conversion and token packing for latents * make style & quality * remove prenorm * add sigma param to ltx2 i2v * fix copies and add pack latents to i2v * Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Infer latent dims if latents/audio_latents is supplied * add note for predefined sigmas * run make style and quality * revert distill timesteps & set original_state_dict_repo_idd to default None * add latent normalize * add create noised state, delete last sigmas * remove normalize step in latent upsample pipeline and move it to ltx2 pipeline * add create noise latent to i2v pipeline * fix copies * parse none value in weight conversion script * explicit shape handling * Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * make style * add two stage inference tests * add ltx2 documentation * update i2v expected_audio_slice * Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Apply suggestion from @dg845 Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update ltx2.md to remove one-stage example Removed one-stage generation example code and added comments for noise scale in two-stage generation. 
--------- Co-authored-by: Sayak Paul Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> Co-authored-by: Daniel Gu --- docs/source/en/api/pipelines/ltx2.md | 173 ++++++++++++++++++ scripts/convert_ltx2_to_diffusers.py | 38 ++-- .../autoencoders/autoencoder_kl_ltx2_audio.py | 4 +- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 123 +++++++++++-- .../ltx2/pipeline_ltx2_image2video.py | 120 +++++++++--- .../ltx2/pipeline_ltx2_latent_upsample.py | 14 -- src/diffusers/pipelines/ltx2/utils.py | 6 + tests/pipelines/ltx2/test_ltx2.py | 52 +++++- tests/pipelines/ltx2/test_ltx2_image2video.py | 52 +++++- 9 files changed, 508 insertions(+), 74 deletions(-) create mode 100644 src/diffusers/pipelines/ltx2/utils.py diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index 4c6860daf0..24776b4230 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -24,6 +24,179 @@ You can find all the original LTX-Video checkpoints under the [Lightricks](https The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2). +## Two-stages Generation +Recommended pipeline to achieve production quality generation, this pipeline is composed of two stages: + +- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning. +- Stage 2: Upsample the Stage 1 output by 2 and refine details using a distilled LoRA model to improve fidelity and visual quality. Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness. + +Sample usage of text-to-video two stages pipeline + +```py +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video + +device = "cuda:0" +width = 768 +height = 512 + +pipe = LTX2Pipeline.from_pretrained( + "Lightricks/LTX-2", torch_dtype=torch.bfloat16 +) +pipe.enable_sequential_cpu_offload(device=device) + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." 
+ +# Stage 1 default (non-distilled) inference +frame_rate = 24.0 +video_latent, audio_latent = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=40, + sigmas=None, + guidance_scale=4.0, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + "Lightricks/LTX-2", + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +# Load Stage 2 distilled LoRA +pipe.load_lora_weights( + "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors" +) +pipe.set_adapters("stage_2_distilled", 1.0) +# VAE tiling is usually necessary to avoid OOM error when VAE decoding +pipe.vae.enable_tiling() +# Change scheduler to use Stage 2 distilled sigmas as is +new_scheduler = FlowMatchEulerDiscreteScheduler.from_config( + pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None +) +pipe.scheduler = new_scheduler +# Stage 2 inference with distilled LoRA and sigmas +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=3, + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218 + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_lora_distilled_sample.mp4", +) +``` + +## Distilled checkpoint generation +Fastest two-stages generation pipeline using a distilled checkpoint. + +```py +import torch +from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "rootonchair/LTX-2-19b-distilled" + +pipe = LTX2Pipeline.from_pretrained( + model_path, torch_dtype=torch.bfloat16 +) +pipe.enable_sequential_cpu_offload(device=device) + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." 
+ +frame_rate = 24.0 +video_latent, audio_latent = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=8, + sigmas=DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + generator=generator, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + model_path, + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=3, + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178 + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + generator=generator, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_distilled_sample.mp4", +) +``` + ## LTX2Pipeline [[autodoc]] LTX2Pipeline diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 5367113365..88a9c8700a 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = { "up_blocks.4": "up_blocks.1", "up_blocks.5": "up_blocks.2.upsamplers.0", "up_blocks.6": "up_blocks.2", + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", # Common # For all 3D ResNets "res_blocks": "resnets", @@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) - return connectors -def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +def get_ltx2_video_vae_config( + version: str, timestep_conditioning: bool = False +) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { "model_id": "diffusers-internal-dev/dummy-ltx2", @@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A return config, rename_dict, special_keys_remap -def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: - config, 
rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) +def convert_ltx2_video_vae( + original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool +) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning) diffusers_config = config["diffusers_config"] with init_empty_weights(): @@ -659,10 +665,15 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi def get_args(): parser = argparse.ArgumentParser() + def none_or_str(value: str): + if isinstance(value, str) and value.lower() == "none": + return None + return value + parser.add_argument( "--original_state_dict_repo_id", default="Lightricks/LTX-2", - type=str, + type=none_or_str, help="HF Hub repo id with LTX 2.0 checkpoint", ) parser.add_argument( @@ -682,7 +693,7 @@ def get_args(): parser.add_argument( "--combined_filename", default="ltx-2-19b-dev.safetensors", - type=str, + type=none_or_str, help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)", ) parser.add_argument("--vae_prefix", default="vae.", type=str) @@ -701,22 +712,25 @@ def get_args(): parser.add_argument( "--text_encoder_model_id", default="google/gemma-3-12b-it-qat-q4_0-unquantized", - type=str, + type=none_or_str, help="HF Hub id for the LTX 2.0 base text encoder model", ) parser.add_argument( "--tokenizer_id", default="google/gemma-3-12b-it-qat-q4_0-unquantized", - type=str, + type=none_or_str, help="HF Hub id for the LTX 2.0 text tokenizer", ) parser.add_argument( "--latent_upsampler_filename", default="ltx-2-spatial-upscaler-x2-1.0.safetensors", - type=str, + type=none_or_str, help="Latent upsampler filename", ) + parser.add_argument( + "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model" + ) parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") @@ -786,7 +800,9 @@ def main(args): original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) elif combined_ckpt is not None: original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) - vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + vae = convert_ltx2_video_vae( + original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning + ) if not args.full_pipeline and not args.upsample_pipeline: vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 6c9c7dce3d..b29629a29f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -743,8 +743,8 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): # Per-channel statistics for normalizing and denormalizing the latent representation. 
This statics is computed over # the entire dataset and stored in model's checkpoint under AudioVAE state_dict - latents_std = torch.zeros((base_channels,)) - latents_mean = torch.ones((base_channels,)) + latents_std = torch.ones((base_channels,)) + latents_mean = torch.zeros((base_channels,)) self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 9cf8479263..a92a7a2c88 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -584,6 +584,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + @staticmethod def _denormalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 @@ -594,12 +605,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): latents = latents * latents_std / scaling_factor + latents_mean return latents + @staticmethod + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + @staticmethod def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): latents_mean = latents_mean.to(latents.device, latents.dtype) latents_std = latents_std.to(latents.device, latents.dtype) return (latents * latents_std) + latents_mean + @staticmethod + def _create_noised_state( + latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + @staticmethod def _pack_audio_latents( latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None @@ -647,12 +672,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): height: int = 512, width: int = 768, num_frames: int = 121, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, 
self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) height = height // self.vae_spatial_compression_ratio @@ -677,29 +716,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: - return latents.to(device=device, dtype=dtype), latent_length + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -709,7 +749,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -750,9 +790,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, + noise_scale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -788,6 +830,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. 
timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -804,6 +850,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -922,6 +971,21 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels @@ -931,26 +995,45 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): height, width, num_frames, + noise_scale, torch.float32, device, generator, latents, ) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." 
+ ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, + noise_scale=noise_scale, dtype=torch.float32, device=device, generator=generator, @@ -958,7 +1041,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index b1711e2831..04d7ee89c5 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -614,6 +614,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents = latents * latents_std / scaling_factor + latents_mean return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents def _pack_audio_latents( @@ -656,6 +665,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): @@ -671,6 +687,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL height: int = 512, width: int = 704, num_frames: int = 161, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, @@ -686,6 +703,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL if latents is not None: conditioning_mask = 
latents.new_zeros(mask_shape) conditioning_mask[:, :, 0] = 1.0 + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ).squeeze(-1) @@ -737,29 +763,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: - return latents.to(device=device, dtype=dtype), latent_length + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -769,7 +796,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -811,9 +838,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, + noise_scale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -851,6 +880,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -867,6 +900,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -982,6 +1018,26 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL ) # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) image = image.to(device=device, dtype=prompt_embeds.dtype) @@ -994,6 +1050,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL height, width, num_frames, + noise_scale, torch.float32, device, generator, @@ -1002,20 +1059,38 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL if self.do_classifier_free_guidance: conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." 
+ ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, + noise_scale=noise_scale, dtype=torch.float32, device=device, generator=generator, @@ -1023,12 +1098,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL ) # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_num_frames * latent_height * latent_width - - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index a44c40b043..340efd10f2 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -228,17 +228,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline): filtered = latents * scales return filtered - @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents - def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 - ) -> torch.Tensor: - # Normalize latents across the channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents = (latents - latents_mean) * scaling_factor / latents_std - return latents - @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents def _denormalize_latents( @@ -408,9 +397,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline): latents = self.tone_map_latents(latents, tone_map_compression_ratio) if output_type == "latent": - latents = self._normalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) video = latents else: if not 
self.vae.config.timestep_conditioning: diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py new file mode 100644 index 0000000000..f80469817f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -0,0 +1,6 @@ +# Pre-trained sigma values for distilled model are taken from +# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] + +# Reduced schedule for super-resolution stage 2 (subset of distilled values) +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py index 6ffc237250..7d1a3bfc99 100644 --- a/tests/pipelines/ltx2/test_ltx2.py +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -222,7 +222,57 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) expected_audio_slice = torch.tensor( [ - 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + 0.0263, 0.0528, 0.1217, 0.1104, 0.1632, 0.1072, 0.1789, 0.0949, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_two_stages_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + inputs["latents"] = video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.5514, 0.5943, 0.4260, 0.5971, 0.4306, 0.6369, 0.3124, 0.6964, 0.5419, 0.2412, 0.3882, 0.4504, 0.1941, 0.3404, 0.6037, 0.2464 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0252, 0.0526, 0.1211, 0.1119, 0.1638, 0.1042, 0.1776, 0.0948, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 ] ) # fmt: on diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py index 1edae9c0e0..3653e1cfc5 100644 --- a/tests/pipelines/ltx2/test_ltx2_image2video.py +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -224,7 +224,57 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) expected_audio_slice = torch.tensor( [ - 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + 0.0294, 0.0498, 
0.1269, 0.1135, 0.1639, 0.1116, 0.1730, 0.0931, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_two_stages_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + inputs["latents"] = video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.2665, 0.6915, 0.2939, 0.6767, 0.2552, 0.6215, 0.1765, 0.6248, 0.2800, 0.2356, 0.3480, 0.5395, 0.3190, 0.4128, 0.4784, 0.4086 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0273, 0.0490, 0.1253, 0.1129, 0.1655, 0.1057, 0.1707, 0.0943, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 ] ) # fmt: on