Compare commits

...

9 Commits

Author  SHA1  Message  Date
yiyixuxu  658d533144  up  2025-03-06 02:43:50 +01:00
Aryan  9fa964b230  update conversion script; todo: update 0.9.1 checkpoint with timestep_scale_multiplier  2025-03-05 21:53:47 +01:00
Aryan  52d2ec35ed  vae fix  2025-03-05 21:53:24 +01:00
yiyixuxu  661ab0d781  yiyi add testing lines  2025-03-05 21:53:20 +01:00
Aryan  f950ba1da6  update  2025-03-05 20:35:44 +01:00
yiyixuxu  81f24686a6  up  2025-03-05 20:35:13 +01:00
Aryan  d0bdf4bc84  update  2025-03-05 16:03:28 +01:00
Aryan  f35b8077bc  update  2025-03-05 15:42:52 +01:00
Aryan  ea436c4ba2  update  2025-03-05 01:36:30 +01:00
6 changed files with 1407 additions and 31 deletions

View File

@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
"last_scale_shift_table": "scale_shift_table",
}
VAE_095_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
"model.diffusion_model": remove_keys_,
}
VAE_091_SPECIAL_KEYS_REMAP = {
"timestep_scale_multiplier": remove_keys_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
version: str = "0.9.0",
):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
config = {}
if version == "0.9.5":
config["_use_causal_rope_fix"] = True
with init_empty_weights():
transformer = LTXVideoTransformer3DModel()
transformer = LTXVideoTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
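
For orientation, here is a minimal sketch of how these rename dictionaries are applied, assuming the same substring-replacement loop and `update_state_dict_inplace` helper the script already uses for the transformer keys above (`apply_rename_dict` is a hypothetical wrapper; the script inlines this loop):

from typing import Any, Dict

def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Move the tensor to its diffusers-style key.
    state_dict[new_key] = state_dict.pop(old_key)

def apply_rename_dict(state_dict: Dict[str, Any], rename_dict: Dict[str, str]) -> Dict[str, Any]:
    for key in list(state_dict.keys()):
        new_key = key[:]
        # Replacements are applied sequentially, so the insertion order of the dict matters.
        for old, new in rename_dict.items():
            new_key = new_key.replace(old, new)
        update_state_dict_inplace(state_dict, key, new_key)
    return state_dict

# e.g. with VAE_095_RENAME_DICT: "decoder.up_blocks.1.foo" -> "decoder.up_blocks.0.upsamplers.0.foo"
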
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (128, 256, 512, 512),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (4, 3, 3, 3, 4),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True, False),
"decoder_inject_noise": (False, False, False, False, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (5, 6, 7, 8),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
@@ -200,7 +240,36 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
elif version == "0.9.5":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 1024, 2048),
"down_block_types": (
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
return config
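
As a quick sanity check of this 0.9.5 config (not part of the script): the new down blocks no longer all halve space and time together, so the `patch_size * 2 ** sum(spatio_temporal_scaling)` shortcut used by `AutoencoderKLLTXVideo` would give the wrong ratios here, which is presumably why explicit `spatial_compression_ratio` / `temporal_compression_ratio` override arguments are added to the model further down in this diff. Mapping each `downsample_type` to the stride used by `LTXVideoDownsampler3d`:

# Assumed stride per downsample_type, matching the LTXVideo095DownBlock3D code below.
strides = {"spatial": (1, 2, 2), "temporal": (2, 1, 1), "spatiotemporal": (2, 2, 2)}
downsample_type = ("spatial", "temporal", "spatiotemporal", "spatiotemporal")
patch_size, patch_size_t = 4, 1

spatial, temporal = patch_size, patch_size_t
for kind in downsample_type:
    t_stride, h_stride, w_stride = strides[kind]
    temporal *= t_stride
    spatial *= h_stride  # height and width strides are equal for every entry used here

print(spatial, temporal)  # 32 8, versus 64 and 16 from the old shortcut
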
@@ -223,7 +292,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
)
return parser.parse_args()

View File

@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
return hidden_states
class LTXVideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
padding_mode: str = "zeros",
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
# Each output channel of the residual path averages `group_size` rearranged input channels.
self.group_size = (in_channels * self.stride[0] * self.stride[1] * self.stride[2]) // out_channels
out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Causal temporal padding: repeat the leading frame(s) so the time axis is divisible by stride[0].
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
# Residual path: space-to-depth rearrange (B, C, T, H, W) -> (B, C * prod(stride), T/st, H/sh, W/sw),
# then average groups of `group_size` channels down to the target channel count.
residual = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
residual = residual.unflatten(1, (-1, self.group_size))
residual = residual.mean(dim=2)
# Main path: stride-1 causal 3x3x3 conv, followed by the same space-to-depth rearrangement.
hidden_states = self.conv(hidden_states)
hidden_states = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
hidden_states = hidden_states + residual
return hidden_states
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
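
To make the new LTXVideoDownsampler3d above easier to follow, here is a standalone shape check of its space-to-depth plus grouped-mean residual pattern in plain torch, assuming the "spatial" stride of (1, 2, 2); this is a sketch, not the diffusers module itself:

import torch

stride = (1, 2, 2)                       # "spatial" downsampling
in_channels, out_channels = 128, 256
group_size = in_channels * stride[0] * stride[1] * stride[2] // out_channels  # 2

x = torch.randn(1, in_channels, 4, 16, 16)   # (B, C, T, H, W)
r = (
    x.unflatten(4, (-1, stride[2]))
    .unflatten(3, (-1, stride[1]))
    .unflatten(2, (-1, stride[0]))
    .permute(0, 1, 3, 5, 7, 2, 4, 6)
    .flatten(1, 4)                       # channels grow to C * prod(stride) = 512
)
r = r.unflatten(1, (-1, group_size)).mean(dim=2)  # grouped mean brings them down to out_channels
print(r.shape)  # torch.Size([1, 256, 4, 8, 8])
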
@@ -352,6 +403,120 @@ class LTXVideoDownBlock3D(nn.Module):
return hidden_states
class LTXVideo095DownBlock3D(nn.Module):
r"""
Down block used in the LTXVideo 0.9.5 model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
spatio_temporal_scale (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, the output spatial and temporal dimensions are the same as the input.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
downsample_type (`str`, defaults to `"conv"`):
The type of downsampling layer to use: `"conv"`, `"spatial"`, `"temporal"`, or `"spatiotemporal"`.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
downsample_type: str = "conv",
):
super().__init__()
out_channels = out_channels or in_channels
resnets = []
for _ in range(num_layers):
resnets.append(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
)
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList()
if downsample_type == "conv":
self.downsamplers.append(
LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=(2, 2, 2),
is_causal=is_causal,
)
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
)
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
)
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
)
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
else:
hidden_states = resnet(hidden_states, temb, generator)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
print(f" after downsampler: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXVideoMidBlock3d(nn.Module):
r"""
@@ -593,8 +758,15 @@ class LTXVideoEncoder3d(nn.Module):
in_channels: int = 3,
out_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -617,20 +789,37 @@ class LTXVideoEncoder3d(nn.Module):
)
# down blocks
num_block_out_channels = len(block_out_channels)
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
self.down_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
if not is_ltx_095:
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
else:
output_channel = block_out_channels[i + 1]
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
if down_block_types[i] == "LTXVideoDownBlock3D":
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
elif down_block_types[i] == "LTXVideo095DownBlock3D":
down_block = LTXVideo095DownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
downsample_type=downsample_type[i],
)
else:
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
self.down_blocks.append(down_block)
@@ -654,6 +843,8 @@ class LTXVideoEncoder3d(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `LTXVideoEncoder3d` class."""
print(f" inside LTXVideoEncoder3d")
print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
p = self.patch_size
p_t = self.patch_size_t
@@ -667,7 +858,9 @@ class LTXVideoEncoder3d(nn.Module):
)
# Thanks for driving me insane with the weird patching order :(
hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
print(f" before conv_in: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
hidden_states = self.conv_in(hidden_states)
print(f" after conv_in: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
@@ -677,17 +870,22 @@ class LTXVideoEncoder3d(nn.Module):
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
print(f" after down_block: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
hidden_states = self.mid_block(hidden_states)
print(f" after mid_block: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
print(f" before conv_act: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
hidden_states = self.conv_act(hidden_states)
print(f" after conv_act: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
hidden_states = self.conv_out(hidden_states)
print(f" after conv_out: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
last_channel = hidden_states[:, -1:]
last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
hidden_states = torch.cat([hidden_states, last_channel], dim=1)
print(f" output: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
return hidden_states
@@ -794,7 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
self.timestep_scale_multiplier = None
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
@@ -803,6 +1003,9 @@ class LTXVideoDecoder3d(nn.Module):
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if self.timestep_scale_multiplier is not None:
temb = temb * self.timestep_scale_multiplier
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
@@ -891,12 +1094,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
@@ -906,6 +1116,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: Optional[int] = None,
temporal_compression_ratio: Optional[int] = None,
) -> None:
super().__init__()
@@ -913,8 +1125,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
down_block_types=down_block_types,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
downsample_type=downsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
@@ -941,8 +1155,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_compression_ratio is None else spatial_compression_ratio
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_compression_ratio is None else temporal_compression_ratio
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -102,6 +102,7 @@ class LTXVideoRotaryPosEmbed(nn.Module):
patch_size: int = 1,
patch_size_t: int = 1,
theta: float = 10000.0,
_causal_rope_fix: bool = False,
) -> None:
super().__init__()
@@ -112,6 +113,7 @@ class LTXVideoRotaryPosEmbed(nn.Module):
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.theta = theta
self._causal_rope_fix = _causal_rope_fix
def forward(
self,
@@ -119,6 +121,7 @@ class LTXVideoRotaryPosEmbed(nn.Module):
num_frames: int,
height: int,
width: int,
frame_rate: Optional[int] = None,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
@@ -132,9 +135,24 @@ class LTXVideoRotaryPosEmbed(nn.Module):
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if rope_interpolation_scale is not None:
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
if isinstance(rope_interpolation_scale, tuple):
# This will be deprecated in v0.34.0
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
else:
if not self._causal_rope_fix:
grid[:, 0:1] = (
grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
)
else:
grid[:, 0:1] = (
((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
* self.patch_size_t
/ self.base_num_frames
)
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width
grid = grid.flatten(2, 4).transpose(1, 2)
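
A worked check of the `_causal_rope_fix` branch (illustrative numbers; assumes the temporal scale `rope_interpolation_scale[0:1]` equals 8 / fps with fps = 25): the fix moves the temporal coordinate of latent frame k from k * 8 / fps to ((k - 1) * 8 + 1) / fps, clamped at 0, which is consistent with a causal VAE whose first latent frame encodes only a single video frame.

# Temporal RoPE coordinate per latent frame index k, before the
# patch_size_t / base_num_frames normalization (assumed r_t = 8, fps = 25).
r_t, fps = 8, 25
scale_t = r_t / fps
for k in range(4):
    old = k * scale_t
    fixed = max((k - 1) * scale_t + 1 / fps, 0.0)
    print(k, round(old, 2), round(fixed, 2))
# prints: 0 0.0 0.0 / 1 0.32 0.04 / 2 0.64 0.36 / 3 0.96 0.68
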
@@ -315,6 +333,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
caption_channels: int = 4096,
attention_bias: bool = True,
attention_out_bias: bool = True,
_causal_rope_fix: bool = False,
) -> None:
super().__init__()
@@ -336,6 +355,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
patch_size=patch_size,
patch_size_t=patch_size_t,
theta=10000.0,
_causal_rope_fix=_causal_rope_fix,
)
self.transformer_blocks = nn.ModuleList(
@@ -370,7 +390,8 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
frame_rate: int,
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
@@ -389,7 +410,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
if not isinstance(rope_interpolation_scale, torch.Tensor):
msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
deprecate("rope_interpolation_scale", "0.34.0", msg)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:

View File

@@ -694,9 +694,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
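
For reference, the replaced entry is numerically identical to the old one; only the intermediate variable is dropped: 1 / latent_frame_rate = 1 / (frame_rate / vae_temporal_compression_ratio) = vae_temporal_compression_ratio / frame_rate. The same simplification appears in the image-to-video pipeline below.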

File diff suppressed because it is too large

View File

@@ -764,9 +764,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)