Compare commits

...

13 Commits

Author SHA1 Message Date
Aryan
01dae7492b Merge branch 'main' into integrations/ltx-098 2025-09-10 05:12:54 +05:30
Aryan
83bda524fb Revert "try generating in reverse like... like what seems to be done in original codebase"
This reverts commit 0e97669edb.
2025-08-14 03:36:32 +02:00
Aryan
0e97669edb try generating in reverse like... like what seems to be done in original codebase 2025-08-14 03:36:10 +02:00
Aryan
f2264e813f try manually writing logic that kinda makes sense: 2025-08-14 02:56:54 +02:00
Aryan
322d03d6a0 make style 2025-08-14 01:34:53 +02:00
Aryan
3de88be38a refactor 3 2025-08-14 01:34:37 +02:00
Aryan
84c17be0f0 refactor 2 2025-08-14 01:26:39 +02:00
Aryan
47bf390bb4 refactor 1 2025-08-14 01:12:13 +02:00
Aryan
e981399b81 simplify 2025-08-14 01:09:43 +02:00
Aryan
27a43451cb this version kinda works but results still bad 2025-08-13 22:06:31 +02:00
Aryan
16b82a546f update progress 2025-08-13 00:34:23 +02:00
Aryan
13dbe8cdb2 Merge branch 'main' into integrations/ltx-098 2025-08-07 17:32:12 +05:30
Aryan
47d93ce35e update 2025-08-07 13:58:45 +02:00
7 changed files with 1353 additions and 6 deletions

View File

@@ -254,8 +254,8 @@ export_to_video(video, "output.mp4", fps=24)
pipeline.vae.enable_tiling()
def round_to_nearest_resolution_acceptable_by_vae(height, width):
height = height - (height % pipeline.vae_temporal_compression_ratio)
width = width - (width % pipeline.vae_temporal_compression_ratio)
height = height - (height % pipeline.vae_spatial_compression_ratio)
width = width - (width % pipeline.vae_spatial_compression_ratio)
return height, width
prompt = """
@@ -325,6 +325,95 @@ export_to_video(video, "output.mp4", fps=24)
</details>
- LTX-Video 0.9.8 distilled model is similar to the 0.9.7 variant. It is guidance and timestep-distilled, and similar inference code can be used as above. An improvement of this version is that it supports generating very long videos. Additionally, it supports using tone mapping to improve the quality of the generated video using the `tone_map_compression_ratio` parameter. The default value of `0.6` is recommended.
<details>
<summary>Show example code</summary>
```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
from diffusers.utils import export_to_video, load_video
pipeline = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.8-13B-distilled", torch_dtype=torch.bfloat16)
# TODO: Update the checkpoint here once updated in LTX org
upsampler = LTXLatentUpsamplerModel.from_pretrained("a-r-r-o-w/LTX-0.9.8-Latent-Upsampler", torch_dtype=torch.bfloat16)
pipe_upsample = LTXLatentUpsamplePipeline(vae=pipeline.vae, latent_upsampler=upsampler).to(torch.bfloat16)
pipeline.to("cuda")
pipe_upsample.to("cuda")
pipeline.vae.enable_tiling()
def round_to_nearest_resolution_acceptable_by_vae(height, width):
height = height - (height % pipeline.vae_spatial_compression_ratio)
width = width - (width % pipeline.vae_spatial_compression_ratio)
return height, width
prompt = """The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys.The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature."""
# prompt = """A woman walks away from a white Jeep parked on a city street at night, then ascends a staircase and knocks on a door. The woman, wearing a dark jacket and jeans, walks away from the Jeep parked on the left side of the street, her back to the camera; she walks at a steady pace, her arms swinging slightly by her sides; the street is dimly lit, with streetlights casting pools of light on the wet pavement; a man in a dark jacket and jeans walks past the Jeep in the opposite direction; the camera follows the woman from behind as she walks up a set of stairs towards a building with a green door; she reaches the top of the stairs and turns left, continuing to walk towards the building; she reaches the door and knocks on it with her right hand; the camera remains stationary, focused on the doorway; the scene is captured in real-life footage."""
negative_prompt = "bright colors, symbols, graffiti, watermarks, worst quality, inconsistent motion, blurry, jittery, distorted"
expected_height, expected_width = 480, 832
downscale_factor = 2 / 3
# num_frames = 161
num_frames = 361
# 1. Generate video at smaller resolution
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
latents = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=downscaled_width,
height=downscaled_height,
num_frames=num_frames,
timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],
decode_timestep=0.05,
decode_noise_scale=0.025,
image_cond_noise_scale=0.0,
guidance_scale=1.0,
guidance_rescale=0.7,
generator=torch.Generator().manual_seed(0),
output_type="latent",
).frames
# 2. Upscale generated video using latent upsampler with fewer inference steps
# The available latent upsampler upscales the height/width by 2x
upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
upscaled_latents = pipe_upsample(
latents=latents,
adain_factor=1.0,
tone_map_compression_ratio=0.6,
output_type="latent"
).frames
# 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
video = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=upscaled_width,
height=upscaled_height,
num_frames=num_frames,
denoise_strength=0.999, # Effectively, 4 inference steps out of 5
timesteps=[1000, 909, 725, 421, 0],
latents=upscaled_latents,
decode_timestep=0.05,
decode_noise_scale=0.025,
image_cond_noise_scale=0.0,
guidance_scale=1.0,
guidance_rescale=0.7,
generator=torch.Generator().manual_seed(0),
output_type="pil",
).frames[0]
# 4. Downscale the video to the expected resolution
video = [frame.resize((expected_width, expected_height)) for frame in video]
export_to_video(video, "output.mp4", fps=24)
```
</details>
- LTX-Video supports LoRAs with [`~loaders.LTXVideoLoraLoaderMixin.load_lora_weights`].
<details>

View File

@@ -369,6 +369,15 @@ def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
"spatial_upsample": True,
"temporal_upsample": False,
}
elif version == "0.9.8":
config = {
"in_channels": 128,
"mid_channels": 512,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
@@ -402,7 +411,7 @@ def get_args():
"--version",
type=str,
default="0.9.0",
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7", "0.9.8"],
help="Version of the LTX model",
)
return parser.parse_args()

View File

@@ -491,6 +491,7 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionInfinitePipeline",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
@@ -1145,6 +1146,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXConditionInfinitePipeline,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,

View File

@@ -281,6 +281,7 @@ else:
"LTXPipeline",
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
"LTXConditionInfinitePipeline",
"LTXLatentUpsamplePipeline",
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
@@ -681,7 +682,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .ltx import (
LTXConditionInfinitePipeline,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
)
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (

View File

@@ -25,6 +25,7 @@ else:
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_condition_infinite"] = ["LTXConditionInfinitePipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
@@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_condition_infinite import LTXConditionInfinitePipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline

File diff suppressed because it is too large Load Diff

View File

@@ -93,7 +93,8 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
return init_latents
def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
@staticmethod
def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
"""
Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
tensor.
@@ -121,6 +122,39 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
result = torch.lerp(latents, result, factor)
return result
@staticmethod
def tone_map_latents(latents: torch.Tensor, compression: float) -> torch.Tensor:
"""
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
smooth way using a sigmoid-based compression.
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
when controlling dynamic behavior with a `compression` factor.
Args:
latents : torch.Tensor
Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
compression : float
Compression strength in the range [0, 1].
- 0.0: No tone-mapping (identity transform)
- 1.0: Full compression effect
Returns:
torch.Tensor
The tone-mapped latent tensor of the same shape as input.
"""
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
scale_factor = compression * 0.75
abs_latents = torch.abs(latents)
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
filtered = latents * scales
return filtered
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
@@ -172,7 +206,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
"""
self.vae.disable_tiling()
def check_inputs(self, video, height, width, latents):
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
@@ -181,6 +215,9 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
if video is None and latents is None:
raise ValueError("One of `video` or `latents` has to be provided.")
if not (0 <= tone_map_compression_ratio <= 1):
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
@torch.no_grad()
def __call__(
self,
@@ -191,6 +228,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
adain_factor: float = 0.0,
tone_map_compression_ratio: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -200,6 +238,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
height=height,
width=width,
latents=latents,
tone_map_compression_ratio=tone_map_compression_ratio,
)
if video is not None:
@@ -242,6 +281,9 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
else:
latents = latents_upsampled
if tone_map_compression_ratio > 0.0:
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor