Compare commits

..

65 Commits

Author SHA1 Message Date
sayakpaul
403d3f20f7 small nits. 2026-03-05 14:05:53 +05:30
Sayak Paul
441224ac00 Merge branch 'main' into rae 2026-03-05 12:27:28 +05:30
dg845
af0bed007a Merge branch 'main' into rae 2026-03-04 17:04:49 -08:00
Ando
ed9bcfd7a9 Merge branch 'huggingface:main' into rae 2026-03-04 19:21:12 +08:00
Kashif Rasul
05d3edca66 use randn_tensor 2026-03-04 10:16:07 +00:00
Kashif Rasul
f4ec0f1443 remove unittest 2026-03-04 10:12:40 +00:00
Kashif Rasul
fa016b196c rename 2026-03-04 09:55:54 +00:00
Kashif Rasul
33d98a85da fix api 2026-03-04 09:55:25 +00:00
Kashif Rasul
14d918ee88 Merge branch 'main' into rae 2026-03-04 10:18:06 +01:00
Kashif Rasul
bc59324a2f Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 10:12:50 +01:00
Kashif Rasul
b9a5266cec _noising takes a generator 2026-03-04 09:12:19 +00:00
Kashif Rasul
876e930780 remove optional 2026-03-04 09:09:09 +00:00
Kashif Rasul
df1af7d907 Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 10:04:23 +01:00
Kashif Rasul
af75d8b9e2 inline 2026-03-04 09:03:37 +00:00
Kashif Rasul
e805be989e use buffer 2026-03-04 09:00:09 +00:00
Kashif Rasul
3958fda3bf Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 09:53:33 +01:00
Kashif Rasul
196f8a36c7 error out as soon as possible and add comments 2026-03-04 08:52:08 +00:00
Sayak Paul
9c0f96b303 Merge branch 'main' into rae 2026-03-03 17:06:14 +05:30
Kashif Rasul
bc71889852 update training script 2026-03-03 09:10:58 +00:00
Kashif Rasul
3a6689518f add dispatch forward and update conversion script 2026-03-03 09:03:28 +00:00
Kashif Rasul
5817416a19 fix test 2026-03-02 08:11:31 +00:00
Kashif Rasul
e834e498b2 _strip_final_layernorm_affine for training script 2026-02-28 19:40:19 +00:00
Kashif Rasul
f15873af72 strip final layernorm when converting 2026-02-28 19:35:21 +00:00
Sayak Paul
bff48d317e Merge branch 'main' into rae 2026-02-28 22:01:01 +05:30
Kashif Rasul
cd86873ea6 make quality 2026-02-28 16:28:04 +00:00
Kashif Rasul
34787e5b9b use ModelTesterMixin and AutoencoderTesterMixin 2026-02-28 16:22:47 +00:00
Kashif Rasul
9ada5768e5 remove config 2026-02-28 16:05:19 +00:00
Kashif Rasul
8861a8082a fix slow test 2026-02-28 15:57:10 +00:00
Kashif Rasul
03e757ca73 Encoder is frozen 2026-02-28 15:35:28 +00:00
Kashif Rasul
c717498fa3 use image url 2026-02-28 15:08:56 +00:00
Kashif Rasul
1b4a43f59d Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-27 11:43:20 +01:00
Kashif Rasul
6a78767864 Update examples/research_projects/autoencoder_rae/README.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-27 11:42:45 +01:00
Kashif Rasul
663b580418 latebt normalization buffers are now always registered with no-op defaults 2026-02-26 10:45:30 +00:00
Kashif Rasul
d965cabe79 fix conversion script review 2026-02-26 10:44:27 +00:00
Kashif Rasul
5c85781519 fix train script to use pretrained 2026-02-26 10:38:47 +00:00
Kashif Rasul
c71cb44299 Merge branch 'rae' of https://github.com/Ando233/diffusers into rae 2026-02-26 10:30:32 +00:00
Kashif Rasul
dca59233f6 address reviews 2026-02-26 10:30:26 +00:00
Kashif Rasul
b3ffd6344a cleanups 2026-02-26 10:26:30 +00:00
Kashif Rasul
7debd07541 Merge branch 'main' into rae 2026-02-26 11:08:08 +01:00
Kashif Rasul
b297868201 fixes from pretrained weights 2026-02-25 13:38:22 +00:00
Kashif Rasul
28a02eb226 undo last change 2026-02-23 10:05:24 +00:00
Kashif Rasul
61885f37e3 added encoder_image_size config 2026-02-23 09:59:26 +00:00
Kashif Rasul
c68b812cb0 fix entrypoint for instantiating the AutoencoderRAE 2026-02-23 09:40:18 +00:00
Kashif Rasul
d8b2983b9e Merge branch 'main' into rae 2026-02-17 10:10:40 +01:00
Kashif Rasul
d06b501850 fix training script 2026-02-16 13:00:00 +00:00
Kashif Rasul
a4fc9f64b2 simplify mixins 2026-02-16 12:52:20 +00:00
Kashif Rasul
fc5295951a cleanup 2026-02-16 12:40:36 +00:00
Kashif Rasul
96520c4ff1 move loss to training script 2026-02-16 12:35:18 +00:00
Kashif Rasul
d3cbd5a60b fix argument 2026-02-16 00:03:54 +00:00
Kashif Rasul
906d79a432 input and ground truth sizes have to be the same 2026-02-16 00:02:27 +00:00
Kashif Rasul
9522e68a5b example traiing script 2026-02-15 23:56:19 +00:00
Kashif Rasul
6a9bde6964 remove unneeded class 2026-02-15 23:55:06 +00:00
Kashif Rasul
e6d449933d use attention 2026-02-15 23:50:52 +00:00
Kashif Rasul
7cbbf271f3 use imports 2026-02-15 23:33:30 +00:00
Kashif Rasul
202b14f6a4 add rae to diffusers script 2026-02-15 23:19:53 +00:00
Kashif Rasul
0d59b22732 cleanup 2026-02-15 23:19:13 +00:00
Kashif Rasul
d7cb12470b use mean and std convention 2026-02-15 22:57:02 +00:00
Kashif Rasul
f06ea7a901 fix latent_mean / latent_var init types to accept config-friendly inputs 2026-02-15 22:51:36 +00:00
Kashif Rasul
25bc9e334c initial doc 2026-02-15 22:44:46 +00:00
Kashif Rasul
24acab0bcc make fix-copies 2026-02-15 22:44:16 +00:00
Kashif Rasul
0850c8cdc9 fix formatting 2026-02-15 22:39:59 +00:00
Kashif Rasul
3ecf89d044 Merge branch 'main' into rae 2026-02-15 23:05:44 +01:00
Ando
a3926d77d7 Merge branch 'main' into rae 2026-01-28 20:31:20 +08:00
wangyuqi
f82cecc298 feat: finish first version of autoencoder_rae 2026-01-28 20:19:31 +08:00
wangyuqi
382aad0a6c feat: implement three RAE encoders(dinov2, siglip2, mae) 2026-01-25 02:54:35 +08:00
20 changed files with 84 additions and 1817 deletions

View File

@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. -->
# HeliosTransformer3DModel
A 14B Real-Time Autogressive Diffusion Transformer model (support T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379) by Peking University & ByteDance & etc.
A 14B Real-Time Autogressive Diffusion Transformer model (support T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University & ByteDance & etc.
The model can be loaded with the following code snippet.

View File

@@ -22,7 +22,7 @@
# Helios
[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan.
[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan.
* <u>We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality.</u> We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page).

View File

@@ -193,179 +193,6 @@ encode_video(
)
```
## Condition Pipeline Generation
You can use `LTX2ConditionPipeline` to specify image and/or video conditions at arbitrary latent indices. For example, we can specify both a first-frame and last-frame condition to perform first-last-frame-to-video (FLF2V) generation:
```py
import torch
from diffusers import LTX2ConditionPipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image
device = "cuda"
width = 768
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"
pipe = LTX2ConditionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload(device=device)
pipe.vae.enable_tiling()
prompt = (
"CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are "
"delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright "
"sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, "
"low-angle perspective."
)
first_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png",
)
last_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png",
)
first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0)
last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0)
conditions = [first_cond, last_cond]
frame_rate = 24.0
video_latent, audio_latent = pipe(
conditions=conditions,
prompt=prompt,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=8,
sigmas=DISTILLED_SIGMA_VALUES,
guidance_scale=1.0,
generator=generator,
output_type="latent",
return_dict=False,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
model_path,
subfolder="latent_upsampler",
torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
latents=video_latent,
output_type="latent",
return_dict=False,
)[0]
video, audio = pipe(
latents=upscaled_video_latent,
audio_latents=audio_latent,
prompt=prompt,
width=width * 2,
height=height * 2,
num_inference_steps=3,
sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
generator=generator,
guidance_scale=1.0,
output_type="np",
return_dict=False,
)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_distilled_flf2v.mp4",
)
```
You can use both image and video conditions:
```py
import torch
from diffusers import LTX2ConditionPipeline
from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image, load_video
device = "cuda"
width = 768
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"
pipe = LTX2ConditionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload(device=device)
pipe.vae.enable_tiling()
prompt = (
"The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is "
"divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features "
"dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered "
"clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, "
"with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The "
"landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the "
"solitude and beauty of a winter drive through a mountainous region."
)
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
cond_video = load_video(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
)
cond_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
)
video_cond = LTX2VideoCondition(frames=cond_video, index=0, strength=1.0)
image_cond = LTX2VideoCondition(frames=cond_image, index=8, strength=1.0)
conditions = [video_cond, image_cond]
frame_rate = 24.0
video, audio = pipe(
conditions=conditions,
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=40,
guidance_scale=4.0,
generator=generator,
output_type="np",
return_dict=False,
)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_cond_video.mp4",
)
```
Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static.
## LTX2Pipeline
[[autodoc]] LTX2Pipeline
@@ -378,12 +205,6 @@ Because the conditioning is done via latent frames, the 8 data space frames corr
- all
- __call__
## LTX2ConditionPipeline
[[autodoc]] LTX2ConditionPipeline
- all
- __call__
## LTX2LatentUpsamplePipeline
[[autodoc]] LTX2LatentUpsamplePipeline

View File

@@ -130,4 +130,4 @@ pipe.to("cuda")
Learn more about Helios with the following resources.
- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features.
- The research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379) for more details.
- The research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) for more details.

View File

@@ -131,4 +131,4 @@ pipe.to("cuda")
通过以下资源了解有关 Helios 的更多信息:
- [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能;
- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379)。
- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/)。

View File

@@ -1715,7 +1715,7 @@ def main(args):
packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)
# handle guidance
if unwrap_model(transformer).config.guidance_embeds:
if transformer.config.guidance_embeds:
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:

View File

@@ -1682,7 +1682,7 @@ def main(args):
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
# handle guidance
if unwrap_model(transformer).config.guidance_embeds:
if transformer.config.guidance_embeds:
guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:

View File

@@ -335,11 +335,8 @@ def convert(args: argparse.Namespace) -> None:
model_keys |= {name for name, _ in model.named_buffers()}
loaded_keys = set(full_state_dict.keys())
missing = model_keys - loaded_keys
# decoder_pos_embed is initialized in-model. trainable_cls_token is only
# allowed to be missing if it was absent in the source decoder checkpoint.
allowed_missing = {"decoder.decoder_pos_embed"}
if "trainable_cls_token" not in decoder_state_dict:
allowed_missing.add("decoder.trainable_cls_token")
# trainable_cls_token and decoder_pos_embed are initialized, not loaded from original checkpoint
allowed_missing = {"decoder.trainable_cls_token", "decoder.decoder_pos_embed"}
if missing - allowed_missing:
print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}")

View File

@@ -572,7 +572,6 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ConditionPipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
@@ -1321,7 +1320,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ConditionPipeline,
LTX2ImageToVideoPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,

View File

@@ -108,17 +108,8 @@ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
class AttentionProcessorSkipHook(ModelHook):
def __init__(
self,
skip_processor_output_fn: Callable,
skip_attn_scores_fn: Callable | None = None,
skip_attention_scores: bool = False,
dropout: float = 1.0,
):
super().__init__()
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
self.skip_processor_output_fn = skip_processor_output_fn
# STG default: return the values as attention output
self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
self.skip_attention_scores = skip_attention_scores
self.dropout = dropout
@@ -128,22 +119,8 @@ class AttentionProcessorSkipHook(ModelHook):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores")
if processor_supports_skip_fn:
module.processor._skip_attn_scores = True
module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn
# Use try block in case attn processor raises an exception
try:
if processor_supports_skip_fn:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
# Fallback to torch native SDPA intercept approach
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
finally:
if processor_supports_skip_fn:
module.processor._skip_attn_scores = False
module.processor._skip_attn_scores_fn = None
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)

View File

@@ -38,7 +38,6 @@ from ..utils import (
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_kernels_version,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
@@ -319,7 +318,6 @@ class _HubKernelConfig:
repo_id: str
function_attr: str
revision: str | None = None
version: int | None = None
kernel_fn: Callable | None = None
wrapped_forward_attr: str | None = None
wrapped_backward_attr: str | None = None
@@ -329,34 +327,31 @@ class _HubKernelConfig:
# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_func",
revision="fake-ops-return-probs",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
version=1,
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_varlen_func",
version=1,
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
version=1,
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_varlen_func",
version=1,
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
),
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
repo_id="kernels-community/sage-attention",
function_attr="sageattn",
version=1,
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
),
}
@@ -526,10 +521,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
raise RuntimeError(
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
if not is_kernels_version(">=", "0.12"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
@@ -703,7 +694,7 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version)
kernel_module = get_kernel(config.repo_id, revision=config.revision)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)

View File

@@ -226,17 +226,7 @@ class ViTMAELayer(nn.Module):
class RAEDecoder(nn.Module):
"""
Decoder implementation ported from RAE-main to keep checkpoint compatibility.
Key attributes (must match checkpoint keys):
- decoder_embed
- decoder_pos_embed
- decoder_layers
- decoder_norm
- decoder_pred
- trainable_cls_token
"""
"""Lightweight RAE decoder."""
def __init__(
self,
@@ -263,11 +253,7 @@ class RAEDecoder(nn.Module):
self.num_patches = num_patches
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
grid_size = int(num_patches**0.5)
pos_embed = get_2d_sincos_pos_embed(
decoder_hidden_size, grid_size, cls_token=True, extra_tokens=1, output_type="pt"
)
self.register_buffer("decoder_pos_embed", pos_embed.unsqueeze(0).float(), persistent=False)
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))
self.decoder_layers = nn.ModuleList(
[
@@ -289,8 +275,27 @@ class RAEDecoder(nn.Module):
self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True)
self.gradient_checkpointing = False
self._initialize_weights(num_patches)
self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
def _initialize_weights(self, num_patches: int):
# Skip initialization when parameters are on meta device (e.g. during
# accelerate.init_empty_weights() used by low_cpu_mem_usage loading).
# The weights are initialized.
if self.decoder_pos_embed.device.type == "meta":
return
grid_size = int(num_patches**0.5)
pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
grid_size,
cls_token=True,
extra_tokens=1,
output_type="pt",
device=self.decoder_pos_embed.device,
)
self.decoder_pos_embed.data.copy_(pos_embed.unsqueeze(0).to(dtype=self.decoder_pos_embed.dtype))
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
embeddings_positions = embeddings.shape[1] - 1
num_positions = self.decoder_pos_embed.shape[1] - 1
@@ -440,7 +445,6 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
# NOTE: gradient checkpointing is not wired up for this model yet.
_supports_gradient_checkpointing = False
_no_split_modules = ["ViTMAELayer"]
_keys_to_ignore_on_load_unexpected = ["decoder.decoder_pos_embed"]
@register_to_config
def __init__(
@@ -473,6 +477,31 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
)
if encoder_input_size % encoder_patch_size != 0:
raise ValueError(
f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
)
decoder_patch_size = patch_size
if decoder_patch_size <= 0:
raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")
num_patches = (encoder_input_size // encoder_patch_size) ** 2
grid = int(sqrt(num_patches))
if grid * grid != num_patches:
raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")
derived_image_size = decoder_patch_size * grid
if image_size is None:
image_size = derived_image_size
else:
image_size = int(image_size)
if image_size != derived_image_size:
raise ValueError(
f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
)
def _to_config_compatible(value: Any) -> Any:
if isinstance(value, torch.Tensor):
return value.detach().cpu().tolist()
@@ -497,21 +526,6 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
latents_std=_to_config_compatible(latents_std),
)
self.encoder_input_size = encoder_input_size
self.noise_tau = float(noise_tau)
self.reshape_to_2d = bool(reshape_to_2d)
self.use_encoder_loss = bool(use_encoder_loss)
# Validate early, before building the (potentially large) encoder/decoder.
encoder_patch_size = int(encoder_patch_size)
if self.encoder_input_size % encoder_patch_size != 0:
raise ValueError(
f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
)
decoder_patch_size = int(patch_size)
if decoder_patch_size <= 0:
raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")
# Frozen representation encoder (built from config, no downloads)
self.encoder: nn.Module = _build_encoder(
encoder_type=encoder_type,
@@ -520,22 +534,7 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
num_hidden_layers=encoder_num_hidden_layers,
)
self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]
num_patches = (self.encoder_input_size // encoder_patch_size) ** 2
grid = int(sqrt(num_patches))
if grid * grid != num_patches:
raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")
derived_image_size = decoder_patch_size * grid
if image_size is None:
image_size = derived_image_size
else:
image_size = int(image_size)
if image_size != derived_image_size:
raise ValueError(
f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
)
num_patches = (encoder_input_size // encoder_patch_size) ** 2
# Encoder input normalization stats (ImageNet defaults)
if encoder_norm_mean is None:
@@ -570,6 +569,7 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
num_channels=int(num_channels),
image_size=int(image_size),
)
self.num_patches = int(num_patches)
self.decoder_patch_size = int(decoder_patch_size)
self.decoder_image_size = int(image_size)
@@ -579,16 +579,19 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
# Per-sample random sigma in [0, noise_tau]
noise_sigma = self.noise_tau * torch.rand(
noise_sigma = self.config.noise_tau * torch.rand(
(x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator
)
return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)
def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
_, _, h, w = x.shape
if h != self.encoder_input_size or w != self.encoder_input_size:
if h != self.config.encoder_input_size or w != self.config.encoder_input_size:
x = F.interpolate(
x, size=(self.encoder_input_size, self.encoder_input_size), mode="bicubic", align_corners=False
x,
size=(self.config.encoder_input_size, self.config.encoder_input_size),
mode="bicubic",
align_corners=False,
)
mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
std = self.encoder_std.to(device=x.device, dtype=x.dtype)
@@ -617,10 +620,10 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
else:
tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C)
if self.training and self.noise_tau > 0:
if self.training and self.config.noise_tau > 0:
tokens = self._noising(tokens, generator=generator)
if self.reshape_to_2d:
if self.config.reshape_to_2d:
b, n, c = tokens.shape
side = int(sqrt(n))
if side * side != n:
@@ -657,7 +660,7 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
z = self._denormalize_latents(z)
if self.reshape_to_2d:
if self.config.reshape_to_2d:
b, c, h, w = z.shape
tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C)
else:
@@ -666,7 +669,7 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
logits = self.decoder(tokens, return_dict=True).logits
x_rec = self.decoder.unpatchify(logits)
x_rec = self._denormalize_image(x_rec)
return x_rec.to(device=z.device)
return x_rec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:

View File

@@ -292,12 +292,7 @@ else:
"LTXLatentUpsamplePipeline",
"LTXI2VLongMultiPromptPipeline",
]
_import_structure["ltx2"] = [
"LTX2Pipeline",
"LTX2ConditionPipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
]
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
@@ -736,7 +731,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXLatentUpsamplePipeline,
LTXPipeline,
)
from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline

View File

@@ -76,8 +76,6 @@ EXAMPLE_DOC_STRING = """
def optimized_scale(positive_flat, negative_flat):
positive_flat = positive_flat.float()
negative_flat = negative_flat.float()
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition

View File

@@ -25,7 +25,6 @@ else:
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
@@ -41,7 +40,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .connectors import LTX2TextConnectors
from .latent_upsampler import LTX2LatentUpsamplerModel
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_condition import LTX2ConditionPipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder

File diff suppressed because it is too large Load Diff

View File

@@ -86,7 +86,6 @@ from .import_utils import (
is_inflect_available,
is_invisible_watermark_available,
is_kernels_available,
is_kernels_version,
is_kornia_available,
is_librosa_available,
is_matplotlib_available,

View File

@@ -2147,21 +2147,6 @@ class LongCatImagePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LTX2ConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTX2ImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -724,22 +724,6 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_kernels_version(operation: str, version: str):
"""
Compares the current Kernels version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _kernels_available:
return False
return compare_versions(parse(_kernels_version), operation, version)
@cache
def is_hf_hub_version(operation: str, version: str):
"""

View File

@@ -25,9 +25,9 @@ from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_i
class VideoProcessor(VaeImageProcessor):
r"""Simple video processor."""
def preprocess_video(self, video, height: int | None = None, width: int | None = None, **kwargs) -> torch.Tensor:
def preprocess_video(self, video, height: int | None = None, width: int | None = None) -> torch.Tensor:
r"""
Preprocesses input video(s). Keyword arguments will be forwarded to `VaeImageProcessor.preprocess`.
Preprocesses input video(s).
Args:
video (`list[PIL.Image]`, `list[list[PIL.Image]]`, `torch.Tensor`, `np.array`, `list[torch.Tensor]`, `list[np.array]`):
@@ -49,10 +49,6 @@ class VideoProcessor(VaeImageProcessor):
width (`int`, *optional*`, defaults to `None`):
The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get
the default width.
Returns:
`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`:
A 5D tensor holding the batched channels-first video(s).
"""
if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
warnings.warn(
@@ -83,7 +79,7 @@ class VideoProcessor(VaeImageProcessor):
"Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
)
video = torch.stack([self.preprocess(img, height=height, width=width, **kwargs) for img in video], dim=0)
video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0)
# move the number of channels before the number of frames.
video = video.permute(0, 2, 1, 3, 4)
@@ -91,11 +87,10 @@ class VideoProcessor(VaeImageProcessor):
return video
def postprocess_video(
self, video: torch.Tensor, output_type: str = "np", **kwargs
self, video: torch.Tensor, output_type: str = "np"
) -> np.ndarray | torch.Tensor | list[PIL.Image.Image]:
r"""
Converts a video tensor to a list of frames for export. Keyword arguments will be forwarded to
`VaeImageProcessor.postprocess`.
Converts a video tensor to a list of frames for export.
Args:
video (`torch.Tensor`): The video as a tensor.
@@ -105,7 +100,7 @@ class VideoProcessor(VaeImageProcessor):
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = self.postprocess(batch_vid, output_type, **kwargs)
batch_output = self.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":