Merge branch 'main' into wan-layerwise-upcasting-tests
@@ -24,6 +24,179 @@ You can find all the original LTX-Video checkpoints under the [Lightricks](https

The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).

## Two-Stage Generation

This is the recommended pipeline for production-quality generation. It is composed of two stages:

- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning.
- Stage 2: Upsample the Stage 1 output by 2x and refine details with a distilled LoRA to improve fidelity and visual quality. Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness.

Sample usage of the text-to-video two-stage pipeline:

```py
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

# Stage 1 default (non-distilled) inference
frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=40,
    sigmas=None,
    guidance_scale=4.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "Lightricks/LTX-2",
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling is usually necessary to avoid OOM errors during VAE decoding
pipe.vae.enable_tiling()
# Change the scheduler to use the Stage 2 distilled sigmas as-is
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
# Stage 2 inference with the distilled LoRA and sigmas
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],  # renoise with the first sigma value, see https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_lora_distilled_sample.mp4",
)
```
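
Stage 2 does not start from pure noise: passing `noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0]` renoises the upsampled latents before the short distilled schedule runs. Below is a minimal sketch of the interpolation the pipeline applies internally (it mirrors the `_create_noised_state` helper that appears later in this diff; the tensor shapes are illustrative only):

```py
import torch

noise_scale = 0.909375  # STAGE_2_DISTILLED_SIGMA_VALUES[0]
latents = torch.randn(1, 256, 128)  # illustrative packed video latents [B, seq, features]
noise = torch.randn_like(latents)

# Linear interpolation between fresh noise and the upsampled Stage 1 latents
noised_latents = noise_scale * noise + (1 - noise_scale) * latents
```

Because `noise_scale` equals the first sigma of the Stage 2 schedule, the renoised latents sit at the noise level where the 3-step distilled refinement begins.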

## Distilled checkpoint generation

The fastest two-stage generation pipeline, using a distilled checkpoint.

```py
import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda"
width = 768
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"

pipe = LTX2Pipeline.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    generator=generator,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    model_path,
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],  # renoise with the first sigma value, see https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    generator=generator,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)
```

## LTX2Pipeline

[[autodoc]] LTX2Pipeline
@@ -83,25 +83,6 @@ Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-

> [!TIP]
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.

## autoquant

torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. At the moment, Diffusers only supports it for individual models.

```py
import torch
from diffusers import DiffusionPipeline
from torchao.quantization import autoquant

# Load the pipeline
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)

transformer = autoquant(pipeline.transformer)
```
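
Autoquant benchmarks its candidate strategies on the input shapes the model actually sees, so a natural next step is to run a generation with the wrapped transformer. The continuation below is a hedged sketch, not part of the official example; the prompt, step count, and guidance value are illustrative choices for FLUX.1-schnell:

```py
# Use the autoquant-wrapped transformer and run one generation so it can
# profile real input shapes and settle on a quantization strategy.
pipeline.transformer = transformer
image = pipeline(
    "a photo of a dog in a field",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("autoquant_sample.png")
```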

## Supported quantization types

torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.
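
For example, a weight-only int8 quantization of a single model can be requested through `TorchAoConfig`. This is a minimal sketch; the model id and `subfolder` follow the usual FLUX layout, and the `"int8wo"` string follows torchao's quant-type naming:

```py
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")  # weight-only int8
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```
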
@@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
# Common
|
||||
# For all 3D ResNets
|
||||
"res_blocks": "resnets",
|
||||
@@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
|
||||
return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
def get_ltx2_video_vae_config(
|
||||
version: str, timestep_conditioning: bool = False
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
@@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"timestep_conditioning": timestep_conditioning,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
@@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"timestep_conditioning": timestep_conditioning,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
@@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
|
||||
def convert_ltx2_video_vae(
|
||||
original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool
|
||||
) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
@@ -659,10 +665,15 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
def none_or_str(value: str):
|
||||
if isinstance(value, str) and value.lower() == "none":
|
||||
return None
|
||||
return value
|
||||
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id",
|
||||
default="Lightricks/LTX-2",
|
||||
type=str,
|
||||
type=none_or_str,
|
||||
help="HF Hub repo id with LTX 2.0 checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -682,7 +693,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--combined_filename",
|
||||
default="ltx-2-19b-dev.safetensors",
|
||||
type=str,
|
||||
type=none_or_str,
|
||||
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
|
||||
)
|
||||
parser.add_argument("--vae_prefix", default="vae.", type=str)
|
||||
@@ -701,22 +712,25 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--text_encoder_model_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
type=none_or_str,
|
||||
help="HF Hub id for the LTX 2.0 base text encoder model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
type=none_or_str,
|
||||
help="HF Hub id for the LTX 2.0 text tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--latent_upsampler_filename",
|
||||
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
|
||||
type=str,
|
||||
type=none_or_str,
|
||||
help="Latent upsampler filename",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
|
||||
)
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
@@ -786,7 +800,9 @@ def main(args):
|
||||
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
|
||||
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
|
||||
vae = convert_ltx2_video_vae(
|
||||
original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning
|
||||
)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
|
||||
|
||||
|
||||
@@ -743,8 +743,8 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):

        # Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
        # computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
        latents_std = torch.zeros((base_channels,))
        latents_mean = torch.ones((base_channels,))
        latents_std = torch.ones((base_channels,))
        latents_mean = torch.zeros((base_channels,))
        self.register_buffer("latents_mean", latents_mean, persistent=True)
        self.register_buffer("latents_std", latents_std, persistent=True)
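
The swapped defaults matter because these buffers feed the `_normalize_audio_latents` / `_denormalize_audio_latents` helpers shown later in this diff: a zero `latents_std` would divide by zero, whereas std of ones and mean of zeros makes the default normalization an identity until the real per-channel statistics are loaded from the checkpoint. A minimal sketch of that round trip (broadcast shapes here are illustrative, not the exact buffer layout):

```py
import torch

base_channels = 8
latents_mean = torch.zeros(base_channels)
latents_std = torch.ones(base_channels)

audio_latents = torch.randn(1, base_channels, 32, 16)  # illustrative [B, C, L, M]

# normalize: (x - mean) / std ; denormalize: x * std + mean
normalized = (audio_latents - latents_mean.view(1, -1, 1, 1)) / latents_std.view(1, -1, 1, 1)
denormalized = normalized * latents_std.view(1, -1, 1, 1) + latents_mean.view(1, -1, 1, 1)
assert torch.allclose(denormalized, audio_latents)
```
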
@@ -584,6 +584,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
@@ -594,12 +605,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
|
||||
latents_mean = latents_mean.to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.to(latents.device, latents.dtype)
|
||||
return (latents - latents_mean) / latents_std
|
||||
|
||||
@staticmethod
|
||||
def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
|
||||
latents_mean = latents_mean.to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.to(latents.device, latents.dtype)
|
||||
return (latents * latents_std) + latents_mean
|
||||
|
||||
@staticmethod
|
||||
def _create_noised_state(
|
||||
latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
|
||||
):
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
|
||||
noised_latents = noise_scale * noise + (1 - noise_scale) * latents
|
||||
return noised_latents
|
||||
|
||||
@staticmethod
|
||||
def _pack_audio_latents(
|
||||
latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
|
||||
@@ -647,12 +672,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
height: int = 512,
|
||||
width: int = 768,
|
||||
num_frames: int = 121,
|
||||
noise_scale: float = 0.0,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
if latents.ndim == 5:
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
# latents are of shape [B, C, F, H, W], need to be packed
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
if latents.ndim != 3:
|
||||
raise ValueError(
|
||||
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
|
||||
)
|
||||
latents = self._create_noised_state(latents, noise_scale, generator)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
height = height // self.vae_spatial_compression_ratio
|
||||
@@ -677,29 +716,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 8,
|
||||
audio_latent_length: int = 1, # 1 is just a dummy value
|
||||
num_mel_bins: int = 64,
|
||||
num_frames: int = 121,
|
||||
frame_rate: float = 25.0,
|
||||
sampling_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
noise_scale: float = 0.0,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
duration_s = num_frames / frame_rate
|
||||
latents_per_second = (
|
||||
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
|
||||
)
|
||||
latent_length = round(duration_s * latents_per_second)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype), latent_length
|
||||
if latents.ndim == 4:
|
||||
# latents are of shape [B, C, L, M], need to be packed
|
||||
latents = self._pack_audio_latents(latents)
|
||||
if latents.ndim != 3:
|
||||
raise ValueError(
|
||||
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
|
||||
)
|
||||
latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
|
||||
latents = self._create_noised_state(latents, noise_scale, generator)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
# TODO: confirm whether this logic is correct
|
||||
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
||||
|
||||
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
|
||||
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -709,7 +749,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_audio_latents(latents)
|
||||
return latents, latent_length
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
@@ -750,9 +790,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
num_frames: int = 121,
|
||||
frame_rate: float = 24.0,
|
||||
num_inference_steps: int = 40,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
noise_scale: float = 0.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
@@ -788,6 +830,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
num_inference_steps (`int`, *optional*, defaults to 40):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
@@ -804,6 +850,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
noise_scale (`float`, *optional*, defaults to `0.0`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continue denoising.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -922,6 +971,21 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
if latents is not None:
|
||||
if latents.ndim == 5:
|
||||
logger.info(
|
||||
"Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
|
||||
)
|
||||
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
|
||||
elif latents.ndim == 3:
|
||||
logger.warning(
|
||||
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
|
||||
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
|
||||
)
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
@@ -931,26 +995,45 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
noise_scale,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
duration_s = num_frames / frame_rate
|
||||
audio_latents_per_second = (
|
||||
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
|
||||
)
|
||||
audio_num_frames = round(duration_s * audio_latents_per_second)
|
||||
if audio_latents is not None:
|
||||
if audio_latents.ndim == 4:
|
||||
logger.info(
|
||||
"Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred."
|
||||
)
|
||||
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
|
||||
elif audio_latents.ndim == 3:
|
||||
logger.warning(
|
||||
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
|
||||
f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
|
||||
)
|
||||
|
||||
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
|
||||
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
||||
|
||||
num_channels_latents_audio = (
|
||||
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
|
||||
)
|
||||
audio_latents, audio_num_frames = self.prepare_audio_latents(
|
||||
audio_latents = self.prepare_audio_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=num_channels_latents_audio,
|
||||
audio_latent_length=audio_num_frames,
|
||||
num_mel_bins=num_mel_bins,
|
||||
num_frames=num_frames, # Video frames, audio frames will be calculated from this
|
||||
frame_rate=frame_rate,
|
||||
sampling_rate=self.audio_sampling_rate,
|
||||
hop_length=self.audio_hop_length,
|
||||
noise_scale=noise_scale,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -958,7 +1041,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 1024),
|
||||
|
||||
@@ -614,6 +614,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state
|
||||
def _create_noised_state(
|
||||
latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
|
||||
):
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
|
||||
noised_latents = noise_scale * noise + (1 - noise_scale) * latents
|
||||
return noised_latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
|
||||
def _pack_audio_latents(
|
||||
@@ -656,6 +665,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents
|
||||
def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
|
||||
latents_mean = latents_mean.to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.to(latents.device, latents.dtype)
|
||||
return (latents - latents_mean) / latents_std
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents
|
||||
def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
|
||||
@@ -671,6 +687,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
noise_scale: float = 0.0,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
@@ -686,6 +703,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
if latents is not None:
|
||||
conditioning_mask = latents.new_zeros(mask_shape)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
if latents.ndim == 5:
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator)
|
||||
# latents are of shape [B, C, F, H, W], need to be packed
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
).squeeze(-1)
|
||||
@@ -737,29 +763,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 8,
|
||||
audio_latent_length: int = 1, # 1 is just a dummy value
|
||||
num_mel_bins: int = 64,
|
||||
num_frames: int = 121,
|
||||
frame_rate: float = 25.0,
|
||||
sampling_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
noise_scale: float = 0.0,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
duration_s = num_frames / frame_rate
|
||||
latents_per_second = (
|
||||
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
|
||||
)
|
||||
latent_length = round(duration_s * latents_per_second)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype), latent_length
|
||||
if latents.ndim == 4:
|
||||
# latents are of shape [B, C, L, M], need to be packed
|
||||
latents = self._pack_audio_latents(latents)
|
||||
if latents.ndim != 3:
|
||||
raise ValueError(
|
||||
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
|
||||
)
|
||||
latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
|
||||
latents = self._create_noised_state(latents, noise_scale, generator)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
# TODO: confirm whether this logic is correct
|
||||
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
||||
|
||||
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
|
||||
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -769,7 +796,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_audio_latents(latents)
|
||||
return latents, latent_length
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
@@ -811,9 +838,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
num_frames: int = 121,
|
||||
frame_rate: float = 24.0,
|
||||
num_inference_steps: int = 40,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
noise_scale: float = 0.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
@@ -851,6 +880,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
num_inference_steps (`int`, *optional*, defaults to 40):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
@@ -867,6 +900,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
noise_scale (`float`, *optional*, defaults to `0.0`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continue denoising.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -982,6 +1018,26 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
if latents is not None:
|
||||
if latents.ndim == 5:
|
||||
logger.info(
|
||||
"Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
|
||||
)
|
||||
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
|
||||
elif latents.ndim == 3:
|
||||
logger.warning(
|
||||
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
|
||||
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
|
||||
)
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
|
||||
if latents is None:
|
||||
image = self.video_processor.preprocess(image, height=height, width=width)
|
||||
image = image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
@@ -994,6 +1050,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
noise_scale,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
@@ -1002,20 +1059,38 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
if self.do_classifier_free_guidance:
|
||||
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
|
||||
|
||||
duration_s = num_frames / frame_rate
|
||||
audio_latents_per_second = (
|
||||
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
|
||||
)
|
||||
audio_num_frames = round(duration_s * audio_latents_per_second)
|
||||
if audio_latents is not None:
|
||||
if audio_latents.ndim == 4:
|
||||
logger.info(
|
||||
"Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred."
|
||||
)
|
||||
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
|
||||
elif audio_latents.ndim == 3:
|
||||
logger.warning(
|
||||
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
|
||||
f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
|
||||
)
|
||||
|
||||
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
|
||||
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
||||
|
||||
num_channels_latents_audio = (
|
||||
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
|
||||
)
|
||||
audio_latents, audio_num_frames = self.prepare_audio_latents(
|
||||
audio_latents = self.prepare_audio_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=num_channels_latents_audio,
|
||||
audio_latent_length=audio_num_frames,
|
||||
num_mel_bins=num_mel_bins,
|
||||
num_frames=num_frames, # Video frames, audio frames will be calculated from this
|
||||
frame_rate=frame_rate,
|
||||
sampling_rate=self.audio_sampling_rate,
|
||||
hop_length=self.audio_hop_length,
|
||||
noise_scale=noise_scale,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -1023,12 +1098,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("base_image_seq_len", 1024),
|
||||
|
||||
@@ -228,17 +228,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
|
||||
filtered = latents * scales
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
@@ -408,9 +397,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
|
||||
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
else:
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
|
||||
src/diffusers/pipelines/ltx2/utils.py (new file)
@@ -0,0 +1,6 @@

# Pre-trained sigma values for distilled model are taken from
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]

# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
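
These two constants are what the documentation examples above pass as `sigmas`: the full 8-value schedule drives the 8-step distilled Stage 1, the Stage 2 schedule is simply its last three values, and the first Stage 2 entry doubles as the `noise_scale` used for renoising. A quick check in plain Python, relying only on the constants themselves:

```py
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]

assert len(DISTILLED_SIGMA_VALUES) == 8  # matches num_inference_steps=8 in Stage 1
assert STAGE_2_DISTILLED_SIGMA_VALUES == DISTILLED_SIGMA_VALUES[-3:]  # Stage 2 reuses the tail of the schedule
assert STAGE_2_DISTILLED_SIGMA_VALUES[0] == 0.909375  # used as noise_scale when renoising for Stage 2
```
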
@@ -496,6 +496,17 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        h_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
        w_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[2]
        calc_height = height // h_multiple_of * h_multiple_of
        calc_width = width // w_multiple_of * w_multiple_of
        if height != calc_height or width != calc_width:
            logger.warning(
                f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. "
                f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
            )
            height, width = calc_height, calc_width

        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
            guidance_scale_2 = guidance_scale
@@ -637,6 +637,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        h_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
        w_multiple_of = self.vae_scale_factor_spatial * self.transformer.config.patch_size[2]
        calc_height = height // h_multiple_of * h_multiple_of
        calc_width = width // w_multiple_of * w_multiple_of
        if height != calc_height or width != calc_width:
            logger.warning(
                f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. "
                f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
            )
            height, width = calc_height, calc_width

        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
            guidance_scale_2 = guidance_scale
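
Both Wan pipelines snap `height` and `width` down to the nearest multiple required for patchification. A small worked example of that arithmetic (the spatial scale factor and patch size below are assumed illustrative values, not read from a particular Wan checkpoint):

```py
vae_scale_factor_spatial = 8   # assumed typical value
patch_size = (1, 2, 2)         # assumed (t, h, w) transformer patch size

h_multiple_of = vae_scale_factor_spatial * patch_size[1]  # 16
w_multiple_of = vae_scale_factor_spatial * patch_size[2]  # 16

height, width = 482, 720
calc_height = height // h_multiple_of * h_multiple_of  # 480
calc_width = width // w_multiple_of * w_multiple_of    # 720 (already a multiple of 16)
print(calc_height, calc_width)  # 480 720
```
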
@@ -623,7 +623,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
|
||||
if is_torchao_available():
|
||||
# TODO(aryan): Support autoquant and sparsify
|
||||
# TODO(aryan): Support sparsify
|
||||
from torchao.quantization import (
|
||||
float8_dynamic_activation_float8_weight,
|
||||
float8_static_activation_float8_weight,
|
||||
|
||||
@@ -344,7 +344,6 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
# For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
|
||||
@@ -105,7 +105,6 @@ def rescale_zero_terminal_snr(alphas_cumprod):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.Tensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
@@ -175,11 +174,14 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Choose from
|
||||
`leading`, `linspace` or `trailing`.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
snr_shift_scale (`float`, defaults to 3.0):
|
||||
Shift scale for SNR.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -191,15 +193,15 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.00085,
|
||||
beta_end: float = 0.0120,
|
||||
beta_schedule: str = "scaled_linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
timestep_spacing: Literal["leading", "linspace", "trailing"] = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
snr_shift_scale: float = 3.0,
|
||||
):
|
||||
@@ -209,7 +211,15 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float64,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -266,13 +276,20 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None` (the default), the timesteps are not
|
||||
moved.
|
||||
"""
|
||||
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
@@ -311,7 +328,27 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
|
||||
def get_variables(
|
||||
self,
|
||||
alpha_prod_t: torch.Tensor,
|
||||
alpha_prod_t_prev: torch.Tensor,
|
||||
alpha_prod_t_back: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute the variables used for DPM-Solver++ (2M) referencing the original implementation.
|
||||
|
||||
Args:
|
||||
alpha_prod_t (`torch.Tensor`):
|
||||
The cumulative product of alphas at the current timestep.
|
||||
alpha_prod_t_prev (`torch.Tensor`):
|
||||
The cumulative product of alphas at the previous timestep.
|
||||
alpha_prod_t_back (`torch.Tensor`, *optional*):
|
||||
The cumulative product of alphas at the timestep before the previous timestep.
|
||||
|
||||
Returns:
|
||||
`tuple`:
|
||||
A tuple containing the variables `h`, `r`, `lamb`, `lamb_next`.
|
||||
"""
|
||||
lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
|
||||
lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
|
||||
h = lamb_next - lamb
|
||||
@@ -324,7 +361,36 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
return h, None, lamb, lamb_next
|
||||
|
||||
def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
|
||||
def get_mult(
|
||||
self,
|
||||
h: torch.Tensor,
|
||||
r: Optional[torch.Tensor],
|
||||
alpha_prod_t: torch.Tensor,
|
||||
alpha_prod_t_prev: torch.Tensor,
|
||||
alpha_prod_t_back: Optional[torch.Tensor] = None,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor, torch.Tensor],
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
]:
|
||||
"""
|
||||
Compute the multipliers for the previous sample and the predicted original sample.
|
||||
|
||||
Args:
|
||||
h (`torch.Tensor`):
|
||||
The log-SNR difference.
|
||||
r (`torch.Tensor`):
|
||||
The ratio of log-SNR differences.
|
||||
alpha_prod_t (`torch.Tensor`):
|
||||
The cumulative product of alphas at the current timestep.
|
||||
alpha_prod_t_prev (`torch.Tensor`):
|
||||
The cumulative product of alphas at the previous timestep.
|
||||
alpha_prod_t_back (`torch.Tensor`, *optional*):
|
||||
The cumulative product of alphas at the timestep before the previous timestep.
|
||||
|
||||
Returns:
|
||||
`tuple`:
|
||||
A tuple containing the multipliers.
|
||||
"""
|
||||
mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
|
||||
mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5
|
||||
|
||||
@@ -338,13 +404,13 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
old_pred_original_sample: torch.Tensor,
|
||||
old_pred_original_sample: Optional[torch.Tensor],
|
||||
timestep: int,
|
||||
timestep_back: int,
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = False,
|
||||
) -> Union[DDIMSchedulerOutput, Tuple]:
|
||||
@@ -355,8 +421,12 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
old_pred_original_sample (`torch.Tensor`):
|
||||
The predicted original sample from the previous timestep.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
timestep_back (`int`):
|
||||
The timestep to look back to.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
eta (`float`):
|
||||
@@ -436,7 +506,12 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return prev_sample, pred_original_sample
|
||||
else:
|
||||
denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
|
||||
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
|
||||
noise = randn_tensor(
|
||||
sample.shape,
|
||||
generator=generator,
|
||||
device=sample.device,
|
||||
dtype=sample.dtype,
|
||||
)
|
||||
x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise
|
||||
|
||||
prev_sample = x_advanced
|
||||
@@ -524,5 +599,5 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -222,7 +222,57 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
expected_audio_slice = torch.tensor(
|
||||
[
|
||||
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
0.0263, 0.0528, 0.1217, 0.1104, 0.1632, 0.1072, 0.1789, 0.0949, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
video = video.flatten()
|
||||
audio = audio.flatten()
|
||||
generated_video_slice = torch.cat([video[:8], video[-8:]])
|
||||
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
|
||||
|
||||
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_two_stages_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["output_type"] = "latent"
|
||||
first_stage_output = pipe(**inputs)
|
||||
video_latent = first_stage_output.frames
|
||||
audio_latent = first_stage_output.audio
|
||||
|
||||
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
|
||||
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
|
||||
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
|
||||
|
||||
inputs["latents"] = video_latent
|
||||
inputs["audio_latents"] = audio_latent
|
||||
inputs["output_type"] = "pt"
|
||||
second_stage_output = pipe(**inputs)
|
||||
video = second_stage_output.frames
|
||||
audio = second_stage_output.audio
|
||||
|
||||
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
|
||||
self.assertEqual(audio.shape[0], 1)
|
||||
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
|
||||
|
||||
# fmt: off
|
||||
expected_video_slice = torch.tensor(
|
||||
[
|
||||
0.5514, 0.5943, 0.4260, 0.5971, 0.4306, 0.6369, 0.3124, 0.6964, 0.5419, 0.2412, 0.3882, 0.4504, 0.1941, 0.3404, 0.6037, 0.2464
|
||||
]
|
||||
)
|
||||
expected_audio_slice = torch.tensor(
|
||||
[
|
||||
0.0252, 0.0526, 0.1211, 0.1119, 0.1638, 0.1042, 0.1776, 0.0948, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -224,7 +224,57 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
expected_audio_slice = torch.tensor(
|
||||
[
|
||||
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
0.0294, 0.0498, 0.1269, 0.1135, 0.1639, 0.1116, 0.1730, 0.0931, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
video = video.flatten()
|
||||
audio = audio.flatten()
|
||||
generated_video_slice = torch.cat([video[:8], video[-8:]])
|
||||
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
|
||||
|
||||
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_two_stages_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["output_type"] = "latent"
|
||||
first_stage_output = pipe(**inputs)
|
||||
video_latent = first_stage_output.frames
|
||||
audio_latent = first_stage_output.audio
|
||||
|
||||
self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
|
||||
self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
|
||||
self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)
|
||||
|
||||
inputs["latents"] = video_latent
|
||||
inputs["audio_latents"] = audio_latent
|
||||
inputs["output_type"] = "pt"
|
||||
second_stage_output = pipe(**inputs)
|
||||
video = second_stage_output.frames
|
||||
audio = second_stage_output.audio
|
||||
|
||||
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
|
||||
self.assertEqual(audio.shape[0], 1)
|
||||
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
|
||||
|
||||
# fmt: off
|
||||
expected_video_slice = torch.tensor(
|
||||
[
|
||||
0.2665, 0.6915, 0.2939, 0.6767, 0.2552, 0.6215, 0.1765, 0.6248, 0.2800, 0.2356, 0.3480, 0.5395, 0.3190, 0.4128, 0.4784, 0.4086
|
||||
]
|
||||
)
|
||||
expected_audio_slice = torch.tensor(
|
||||
[
|
||||
0.0273, 0.0490, 0.1253, 0.1129, 0.1655, 0.1057, 0.1707, 0.0943, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||