Commit 05c11a0be4 (mirror of https://github.com/huggingface/diffusers.git)

src/diffusers/modular_pipelines/ltx/before_denoise.py (new file, 782 lines)
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import PIL.Image
import torch
from transformers import T5EncoderModel, T5TokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from ...models.autoencoders import AutoencoderKLLTXVideo
from ...models.transformers import LTXVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import LTXPipelineOutput

# NOTE: assumed import locations for the modular-pipeline primitives referenced below
# (PipelineBlock, PipelineState, ComponentSpec, InputParam, OutputParam, FrozenDict,
# ClassifierFreeGuidance, LTXModularPipeline); adjust if the module layout differs.
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import LTXModularPipeline


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
        >>> from diffusers.utils import export_to_video, load_video, load_image

        >>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> # Load input image and video
        >>> video = load_video(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
        ... )
        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
        ... )

        >>> # Create conditioning objects
        >>> condition1 = LTXVideoCondition(
        ...     image=image,
        ...     frame_index=0,
        ... )
        >>> condition2 = LTXVideoCondition(
        ...     video=video,
        ...     frame_index=80,
        ... )

        >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

        >>> # Generate video
        >>> generator = torch.Generator("cuda").manual_seed(0)
        >>> # Text-only conditioning is also supported without the need to pass `conditions`
        >>> video = pipe(
        ...     conditions=[condition1, condition2],
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     width=768,
        ...     height=512,
        ...     num_frames=161,
        ...     num_inference_steps=40,
        ...     generator=generator,
        ... ).frames[0]

        >>> export_to_video(video, "output.mp4", fps=24)
        ```
"""


@dataclass
class LTXVideoCondition:
    """
    Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames.

    Attributes:
        condition (`Union[PIL.Image.Image, List[PIL.Image.Image]]`):
            Either a single image or a list of video frames to condition the video on.
        condition_type (`Literal["image", "video"]`):
            Explicitly indicates whether this is an image or video condition.
        frame_index (`int`):
            The frame index at which the image or video will conditionally affect the video generation.
        strength (`float`, defaults to `1.0`):
            The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied.
    """

    condition: Union[PIL.Image.Image, List[PIL.Image.Image]]
    condition_type: Literal["image", "video"]
    frame_index: int = 0
    strength: float = 1.0

    @property
    def image(self):
        return self.condition if self.condition_type == "image" else None

    @property
    def video(self):
        return self.condition if self.condition_type == "video" else None
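
# Example (illustrative): a single-image condition at the first frame and a clip condition
# starting at frame 80. `first_frame` and `clip_frames` are placeholder inputs, not part of
# this module.
#     image_condition = LTXVideoCondition(condition=first_frame, condition_type="image", frame_index=0)
#     video_condition = LTXVideoCondition(condition=clip_frames, condition_type="video", frame_index=80, strength=0.8)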


# from LTX-Video/ltx_video/schedulers/rf.py
def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None):
    if linear_steps is None:
        linear_steps = num_steps // 2
    if num_steps < 2:
        return torch.tensor([1.0])
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
    const = quadratic_coef * (linear_steps**2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return torch.tensor(sigma_schedule[:-1])
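
# Worked example (illustrative): with num_steps=4 and the default threshold_noise=0.025,
# linear_steps defaults to 2, so the raw schedule is [0.0, 0.0125, 0.025, 0.275, 1.0];
# after flipping (sigma -> 1 - sigma) and dropping the final entry this returns
# tensor([1.0000, 0.9875, 0.9750, 0.7250]), i.e. sigmas that decay slowly at first and
# then quadratically toward zero.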


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
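
# Worked example (illustrative): with the defaults, the shift is linearly interpolated
# between base_shift and max_shift over the sequence length, e.g.
#     calculate_shift(1024)  # 0.5 + (1024 - 256) * 0.65 / 3840 = 0.63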


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
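
# Example (illustrative): pass either a step count or a custom sigma schedule, not both, e.g.
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=40, device="cuda")
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, device="cuda", sigmas=[0.95, 0.7, 0.4, 0.1])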


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
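
# Example (illustrative): typical use right after classifier-free guidance, where the CFG
# combination can overshoot the statistics of the text-conditioned prediction:
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
#     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)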


def get_t5_prompt_embeds(
    text_encoder,
    tokenizer,
    device,
    dtype,
    prompt: Union[str, List[str]],
    repeat_per_prompt: int = 1,
    max_sequence_length: int = 256,
    return_attention_mask: bool = False,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_attention_mask = text_inputs.attention_mask

    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, repeat_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * repeat_per_prompt, seq_len, -1)

    if return_attention_mask:
        prompt_attention_mask = prompt_attention_mask.bool().to(device)
        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(repeat_per_prompt, 1)

        return prompt_embeds, prompt_attention_mask
    else:
        return prompt_embeds
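
# Example (illustrative): encoding a prompt with the pipeline's own T5 components; `pipe`
# is a placeholder for an assembled LTX pipeline object.
#     prompt_embeds, prompt_attention_mask = get_t5_prompt_embeds(
#         text_encoder=pipe.text_encoder,
#         tokenizer=pipe.tokenizer,
#         device="cuda",
#         dtype=torch.bfloat16,
#         prompt="a red fox running through snow",
#         return_attention_mask=True,
#     )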


class LTXTextEncoderStep(PipelineBlock):
    model_name = "ltx"

    @property
    def description(self):
        return "Encode text into text embeddings"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", T5EncoderModel),
            ComponentSpec("tokenizer", T5TokenizerFast),
            ComponentSpec("guider", ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 3.0})),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="prompt", required=True),
            InputParam(name="negative_prompt"),
            InputParam(name="num_videos_per_prompt"),
            InputParam(name="max_sequence_length"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                name="prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="The text embeddings.",
            ),
            OutputParam(
                name="negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="The negative text embeddings.",
            ),
            OutputParam(
                name="prompt_attention_mask",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="The attention mask for the prompt.",
            ),
            OutputParam(
                name="negative_prompt_attention_mask",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="The attention mask for the negative prompt.",
            ),
        ]

    def check_inputs(self, prompt, negative_prompt):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and not isinstance(negative_prompt, (str, list)):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    def __call__(self, components: LTXModularPipeline, state: PipelineState):
        block_state = state.get_block_state(self)

        self.check_inputs(block_state.prompt, block_state.negative_prompt)
        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1

        device = components._execution_device
        dtype = components.text_encoder.dtype

        block_state.prompt = [block_state.prompt] if isinstance(block_state.prompt, str) else block_state.prompt
        batch_size = len(block_state.prompt)

        block_state.prompt_embeds, block_state.prompt_attention_mask = get_t5_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            device=device,
            dtype=dtype,
            prompt=block_state.prompt,
            repeat_per_prompt=block_state.num_videos_per_prompt,
            max_sequence_length=block_state.max_sequence_length,
            return_attention_mask=True,
        )

        if block_state.prepare_unconditional_embeds:
            block_state.negative_prompt = block_state.negative_prompt or ""
            block_state.negative_prompt = (
                batch_size * [block_state.negative_prompt]
                if isinstance(block_state.negative_prompt, str)
                else block_state.negative_prompt
            )

            if batch_size != len(block_state.negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {block_state.negative_prompt} has batch size {len(block_state.negative_prompt)}, but `prompt`:"
                    f" {block_state.prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            block_state.negative_prompt_embeds, block_state.negative_prompt_attention_mask = get_t5_prompt_embeds(
                text_encoder=components.text_encoder,
                tokenizer=components.tokenizer,
                device=device,
                dtype=dtype,
                prompt=block_state.negative_prompt,
                repeat_per_prompt=block_state.num_videos_per_prompt,
                max_sequence_length=block_state.max_sequence_length,
                return_attention_mask=True,
            )
        else:
            block_state.negative_prompt_embeds = None
            block_state.negative_prompt_attention_mask = None

        self.set_block_state(state, block_state)
        return components, state


class LTXVaeEncoderStep(PipelineBlock):
    model_name = "ltx"

    @staticmethod
    def trim_conditioning_sequence(
        start_frame: int, sequence_num_frames: int, target_num_frames: int, scale_factor: int
    ):
        """
        Trim a conditioning sequence to the allowed number of frames.

        Args:
            start_frame (int): The target frame number of the first frame in the sequence.
            sequence_num_frames (int): The number of frames in the sequence.
            target_num_frames (int): The target number of frames in the generated video.
            scale_factor (int): The temporal scale factor for the model.

        Returns:
            int: updated sequence length

        Example:
            If you want to create a video of 16 frames (target_num_frames=16),
            have a condition with 8 frames (sequence_num_frames=8),
            and want to start conditioning at frame 4 (start_frame=4)
            with scale_factor=4:

            - Available frames: 16 - 4 = 12 frames remaining
            - Sequence fits: min(8, 12) = 8 frames
            - Trim to scale factor: (8 - 1) // 4 * 4 + 1 = 7 // 4 * 4 + 1 = 1 * 4 + 1 = 5 frames
            - Result: Condition will use 5 frames starting at frame 4
        """
        num_frames = min(sequence_num_frames, target_num_frames - start_frame)
        # Trim down to a multiple of temporal_scale_factor frames plus 1
        num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
        return num_frames

    @property
    def description(self):
        return "Encode the image or video inputs into latents."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 32})),
            ComponentSpec("vae", AutoencoderKLLTXVideo),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="conditions", required=True),
            InputParam(name="height", required=True),
            InputParam(name="width", required=True),
            InputParam(name="num_frames", required=True),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam(name="generator", required=True),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                name="conditioning_latents", type_hint=List[torch.Tensor], description="The conditioning latents."
            ),
            OutputParam(
                name="conditioning_num_frames",
                type_hint=List[int],
                description="The number of frames in the conditioning data (before encoding).",
            ),
        ]

    def __call__(self, components: LTXModularPipeline, state: PipelineState):
        block_state = state.get_block_state(self)

        device = components._execution_device
        dtype = components.vae.dtype

        latent_mean = components.vae.latents_mean.view(1, -1, 1, 1, 1).to(device, dtype)
        latent_std = components.vae.latents_std.view(1, -1, 1, 1, 1).to(device, dtype)

        conditioning_latents = []
        conditioning_num_frames = []
        for condition in block_state.conditions:
            if condition.condition_type == "image":
                condition_tensor = (
                    components.video_processor.preprocess(condition.image, block_state.height, block_state.width)
                    .unsqueeze(2)
                    .to(device, dtype)
                )
            elif condition.condition_type == "video":
                condition_tensor = components.video_processor.preprocess(
                    condition.video, block_state.height, block_state.width
                )
                num_frames_input = condition_tensor.size(2)
                num_frames_output = self.trim_conditioning_sequence(
                    start_frame=condition.frame_index,
                    sequence_num_frames=num_frames_input,
                    target_num_frames=block_state.num_frames,
                    scale_factor=components.vae_temporal_compression_ratio,
                )
                condition_tensor = condition_tensor[:, :, :num_frames_output]
                condition_tensor = condition_tensor.to(device, dtype)

            cond_num_frames = condition_tensor.size(2)
            if cond_num_frames % components.vae_temporal_compression_ratio != 1:
                raise ValueError(
                    f"Number of frames in the video must be of the form (k * {components.vae_temporal_compression_ratio} + 1) "
                    f"but got {cond_num_frames} frames."
                )

            cond_latent = retrieve_latents(components.vae.encode(condition_tensor), generator=block_state.generator)
            cond_latent = (cond_latent - latent_mean) * 1.0 / latent_std

            conditioning_latents.append(cond_latent)
            conditioning_num_frames.append(cond_num_frames)

        block_state.conditioning_latents = conditioning_latents
        block_state.conditioning_num_frames = conditioning_num_frames
        self.set_block_state(state, block_state)
        return components, state


class LTXSetTimeStepsStep(PipelineBlock):
    model_name = "ltx"

    @property
    def description(self):
        return "Set the time steps for the video generation."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="num_inference_steps", required=True),
            InputParam(name="timesteps", required=True),
            InputParam(name="denoise_strength", required=True),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="timesteps", type_hint=List[int], description="The timesteps to use for inference."),
            OutputParam(name="num_inference_steps", type_hint=int, description="The number of inference steps."),
            OutputParam(name="sigmas", type_hint=List[float], description="The sigmas to use for inference."),
            OutputParam(
                name="latent_sigma",
                type_hint=torch.Tensor,
                description="The latent sigma to use for preparing the latents.",
            ),
        ]

    def __call__(self, components: LTXModularPipeline, state: PipelineState):
        block_state = state.get_block_state(self)

        if block_state.timesteps is None:
            sigmas = linear_quadratic_schedule(block_state.num_inference_steps)
            timesteps = sigmas * 1000
        else:
            timesteps = block_state.timesteps

        device = components._execution_device
        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler, block_state.num_inference_steps, device, timesteps
        )
        block_state.sigmas = components.scheduler.sigmas

        block_state.latent_sigma = None
        if block_state.denoise_strength < 1:
            num_steps = min(
                int(block_state.num_inference_steps * block_state.denoise_strength), block_state.num_inference_steps
            )
            start_index = max(block_state.num_inference_steps - num_steps, 0)
            block_state.sigmas = block_state.sigmas[start_index:]
            block_state.timesteps = block_state.timesteps[start_index:]
            block_state.num_inference_steps = block_state.num_inference_steps - start_index
            block_state.latent_sigma = block_state.sigmas[:1]

        self.set_block_state(state, block_state)
        return components, state
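
# Example (illustrative): with num_inference_steps=40 and denoise_strength=0.5, the last 20
# sigmas/timesteps are kept, num_inference_steps becomes 20, and latent_sigma is the first
# retained sigma, which LTXPrepareLatentsStep uses to partially re-noise user-provided latents.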


class LTXPrepareLatentsStep(PipelineBlock):
    model_name = "ltx"

    @property
    def description(self):
        return "Prepare the latents for the video generation."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="latents"),
            InputParam(name="num_frames", required=True),
            InputParam(name="height", required=True),
            InputParam(name="width", required=True),
            InputParam(name="conditions"),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam(name="conditioning_latents"),
            InputParam(name="conditioning_num_frames"),
            InputParam(name="batch_size"),
            InputParam(name="latent_sigma", required=True),
            InputParam(name="generator", required=True),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="latents", type_hint=torch.Tensor, description="The latents to use for the video generation."),
            OutputParam(
                name="conditioning_mask",
                type_hint=torch.Tensor,
                description="The conditioning mask to use for the video generation.",
            ),
            OutputParam(
                name="extra_conditioning_latents_num_channels",
                type_hint=int,
                description="The number of channels in the extra conditioning latents.",
            ),
        ]

    @staticmethod
    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape
        # [B, C, F // p_t, p_t, H // p, p, W // p, p]. The patch dimensions are then permuted and
        # collapsed into the channel dimension of shape [B, F // p_t * H // p * W // p, C * p_t * p * p]
        # (an ndim=3 tensor). dim=0 is the batch size, dim=1 is the effective video sequence length,
        # dim=2 is the effective number of input features.
        batch_size, num_channels, num_frames, height, width = latents.shape
        post_patch_num_frames = num_frames // patch_size_t
        post_patch_height = height // patch_size
        post_patch_width = width // patch_size
        latents = latents.reshape(
            batch_size,
            -1,
            post_patch_num_frames,
            patch_size_t,
            post_patch_height,
            patch_size,
            post_patch_width,
            patch_size,
        )
        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
        return latents
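
    # Shape example (illustrative): with patch_size=1 and patch_size_t=1, a latent tensor of
    # shape [1, 128, 8, 16, 22] is flattened to tokens of shape [1, 8 * 16 * 22, 128] = [1, 2816, 128].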

    @staticmethod
    def _prepare_video_ids(
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        patch_size: int = 1,
        patch_size_t: int = 1,
        device: torch.device = None,
    ) -> torch.Tensor:
        latent_sample_coords = torch.meshgrid(
            torch.arange(0, num_frames, patch_size_t, device=device),
            torch.arange(0, height, patch_size, device=device),
            torch.arange(0, width, patch_size, device=device),
            indexing="ij",
        )
        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)

        return latent_coords
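
    # Shape example (illustrative): _prepare_video_ids(batch_size=1, num_frames=2, height=3, width=4)
    # returns a coordinate grid of shape [1, 3, 2 * 3 * 4] = [1, 3, 24], one (frame, row, col) triple per token.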

    @staticmethod
    def _scale_video_ids(
        video_ids: torch.Tensor,
        scale_factor: int = 32,
        scale_factor_t: int = 8,
        frame_index: int = 0,
        device: torch.device = None,
    ) -> torch.Tensor:
        scaled_latent_coords = (
            video_ids
            * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None]
        )
        scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0)
        scaled_latent_coords[:, 0] += frame_index

        return scaled_latent_coords

    def __call__(self, components: LTXModularPipeline, state: PipelineState):
        block_state = state.get_block_state(self)

        device = components._execution_device
        dtype = torch.float32
        num_prefix_latent_frames = 2  # hardcoded

        batch_size = block_state.batch_size

        patch_size = components.transformer_spatial_patch_size
        patch_size_t = components.transformer_temporal_patch_size

        num_latent_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
        latent_height = block_state.height // components.vae_spatial_compression_ratio
        latent_width = block_state.width // components.vae_spatial_compression_ratio
        num_channels_latents = components.num_channels_latents

        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)

        noise = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype)
        latent_sigma = None
        if block_state.latent_sigma is not None:
            latent_sigma = block_state.latent_sigma.repeat(batch_size).to(device, dtype)

        if block_state.latents is not None:
            if block_state.latents.shape != shape:
                raise ValueError(
                    f"Latents shape {block_state.latents.shape} does not match expected shape {shape}. Please check the input."
                )
            block_state.latents = block_state.latents.to(device=device, dtype=dtype)
            if latent_sigma is not None:
                block_state.latents = latent_sigma * noise + (1 - latent_sigma) * block_state.latents
        else:
            block_state.latents = noise

        block_state.conditioning_mask = None
        block_state.extra_conditioning_latents_num_channels = 0
        block_state.extra_conditioning_latents = []
        block_state.extra_conditioning_mask = []

        if (
            block_state.conditioning_latents is not None
            and block_state.conditioning_num_frames is not None
            and block_state.conditions is not None
        ):
            block_state.conditioning_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=dtype)

            for condition_latents, num_data_frames, condition in zip(
                block_state.conditioning_latents, block_state.conditioning_num_frames, block_state.conditions
            ):
                strength = condition.strength
                frame_index = condition.frame_index

                condition_latents = condition_latents.to(device, dtype=dtype)
                num_cond_frames = condition_latents.size(2)

                if frame_index == 0:
                    block_state.latents[:, :, :num_cond_frames] = torch.lerp(
                        block_state.latents[:, :, :num_cond_frames], condition_latents, strength
                    )
                    block_state.conditioning_mask[:, :num_cond_frames] = strength

                else:
                    if num_data_frames > 1:
                        if num_cond_frames < num_prefix_latent_frames:
                            raise ValueError(
                                f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
                            )

                        if num_cond_frames > num_prefix_latent_frames:
                            start_frame = frame_index // components.vae_temporal_compression_ratio + num_prefix_latent_frames
                            end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
                            block_state.latents[:, :, start_frame:end_frame] = torch.lerp(
                                block_state.latents[:, :, start_frame:end_frame],
                                condition_latents[:, :, num_prefix_latent_frames:],
                                strength,
                            )
                            block_state.conditioning_mask[:, start_frame:end_frame] = strength
                            condition_latents = condition_latents[:, :, :num_prefix_latent_frames]

                    noise = randn_tensor(
                        condition_latents.shape, generator=block_state.generator, device=device, dtype=dtype
                    )
                    condition_latents = torch.lerp(noise, condition_latents, strength)

                    condition_latents = self._pack_latents(
                        condition_latents,
                        patch_size,
                        patch_size_t,
                    )
                    condition_latents_mask = torch.full(
                        condition_latents.shape[:2], strength, device=device, dtype=dtype
                    )

                    block_state.extra_conditioning_latents.append(condition_latents)
                    block_state.extra_conditioning_mask.append(condition_latents_mask)
                    block_state.extra_conditioning_latents_num_channels += condition_latents.size(1)

        block_state.latents = self._pack_latents(block_state.latents, patch_size, patch_size_t)
        if block_state.conditioning_mask is not None:
            block_state.conditioning_mask = block_state.conditioning_mask.reshape(
                batch_size, 1, num_latent_frames, 1, 1
            ).expand(-1, -1, -1, latent_height, latent_width)
            block_state.conditioning_mask = self._pack_latents(block_state.conditioning_mask, patch_size, patch_size_t)
            block_state.conditioning_mask = block_state.conditioning_mask.squeeze(-1)
            block_state.conditioning_mask = torch.cat(
                [*block_state.extra_conditioning_mask, block_state.conditioning_mask], dim=1
            )

        block_state.latents = torch.cat([*block_state.extra_conditioning_latents, block_state.latents], dim=1)

        self.set_block_state(state, block_state)
        return components, state
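
# Example (illustrative): these before-denoise blocks are meant to be chained; each block reads
# from and writes to a shared PipelineState. A hand-rolled sketch, assuming `components` is an
# assembled LTXModularPipeline and `state` an already-populated PipelineState:
#     for block in (LTXTextEncoderStep(), LTXVaeEncoderStep(), LTXSetTimeStepsStep(), LTXPrepareLatentsStep()):
#         components, state = block(components, state)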

src/diffusers/modular_pipelines/ltx/modular_pipeline.py (new file, 111 lines)
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import PIL.Image
import torch
from transformers import T5EncoderModel, T5TokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from ...models.autoencoders import AutoencoderKLLTXVideo
from ...models.transformers import LTXVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import ModularPipeline
from .pipeline_output import LTXPipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class LTXVideoCondition:
    """
    Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames.

    Attributes:
        image (`PIL.Image.Image`):
            The image to condition the video on.
        video (`List[PIL.Image.Image]`):
            The video to condition the video on.
        frame_index (`int`):
            The frame index at which the image or video will conditionally affect the video generation.
        strength (`float`, defaults to `1.0`):
            The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied.
    """

    image: Optional[PIL.Image.Image] = None
    video: Optional[List[PIL.Image.Image]] = None
    frame_index: int = 0
    strength: float = 1.0


class LTXModularPipeline(ModularPipeline, LTXVideoLoraLoaderMixin):
    r"""
    Modular Pipeline for LTX Video generation.

    Reference: https://github.com/Lightricks/LTX-Video
    """

    @property
    def vae_spatial_compression_ratio(self):
        return self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32

    @property
    def vae_temporal_compression_ratio(self):
        return self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8

    @property
    def transformer_spatial_patch_size(self):
        return self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1

    @property
    def transformer_temporal_patch_size(self):
        return self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1

    @property
    def default_height(self):
        return 512

    @property
    def default_width(self):
        return 704

    @property
    def default_frames(self):
        return 121

    @property
    def num_channels_latents(self):
        return self.transformer.config.in_channels if getattr(self, "transformer", None) is not None else 128
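
    # Example (illustrative): with the default height=512, width=704, num_frames=121 and the
    # default compression ratios above, the prepared latents have
    #     num_latent_frames = (121 - 1) // 8 + 1 = 16, latent_height = 512 // 32 = 16,
    #     latent_width = 704 // 32 = 22, num_channels_latents = 128,
    # which _pack_latents (with patch sizes of 1) flattens to a token sequence of length 16 * 16 * 22 = 5632.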