mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-16 04:37:07 +08:00
Compare commits
5 Commits
docs/model
...
xla-autoen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0b8008fa9 | ||
|
|
d30831683c | ||
|
|
c41a3c3ed8 | ||
|
|
0d79fc2e60 | ||
|
|
e4d219b366 |
@@ -490,6 +490,8 @@
|
||||
- sections:
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/longcat_audio_dit
|
||||
title: LongCat-AudioDiT
|
||||
- local: api/pipelines/stable_audio
|
||||
title: Stable Audio
|
||||
title: Audio
|
||||
|
||||
61
docs/source/en/api/pipelines/longcat_audio_dit.md
Normal file
61
docs/source/en/api/pipelines/longcat_audio_dit.md
Normal file
@@ -0,0 +1,61 @@
|
||||
<!--Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LongCat-AudioDiT
|
||||
|
||||
LongCat-AudioDiT is a text-to-audio diffusion model from Meituan LongCat. The diffusers integration exposes a standard [`DiffusionPipeline`] interface for text-conditioned audio generation.
|
||||
|
||||
This pipeline supports loading the original flat LongCat checkpoint layout from either a local directory or a Hugging Face Hub repository containing:
|
||||
|
||||
- `config.json`
|
||||
- `model.safetensors`
|
||||
|
||||
The loader builds the text encoder, transformer, and VAE from `config.json`, restores component weights from `model.safetensors`, and ties the shared UMT5 embedding when needed.
|
||||
|
||||
This pipeline was adapted from the LongCat-AudioDiT reference implementation: https://github.com/meituan-longcat/LongCat-AudioDiT
|
||||
|
||||
## Usage
|
||||
|
||||
```py
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from diffusers import LongCatAudioDiTPipeline
|
||||
|
||||
pipeline = LongCatAudioDiTPipeline.from_pretrained(
|
||||
"meituan-longcat/LongCat-AudioDiT-1B",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipeline = pipeline.to("cuda")
|
||||
|
||||
audio = pipeline(
|
||||
prompt="A calm ocean wave ambience with soft wind in the background.",
|
||||
audio_end_in_s=5.0,
|
||||
num_inference_steps=16,
|
||||
guidance_scale=4.0,
|
||||
output_type="pt",
|
||||
).audios
|
||||
|
||||
output = audio[0, 0].float().cpu().numpy()
|
||||
sf.write("longcat.wav", output, pipeline.sample_rate)
|
||||
```
|
||||
|
||||
## Tips
|
||||
|
||||
- `audio_end_in_s` is the most direct way to control output duration.
|
||||
- `output_type="pt"` returns a PyTorch tensor shaped `(batch, channels, samples)`.
|
||||
|
||||
## LongCatAudioDiTPipeline
|
||||
|
||||
[[autodoc]] LongCatAudioDiTPipeline
|
||||
- all
|
||||
- __call__
|
||||
- from_pretrained
|
||||
@@ -29,6 +29,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
|---|---|
|
||||
| [AnimateDiff](animatediff) | text2video |
|
||||
| [AudioLDM2](audioldm2) | text2audio |
|
||||
| [LongCat-AudioDiT](longcat_audio_dit) | text2audio |
|
||||
| [AuraFlow](aura_flow) | text2image |
|
||||
| [Bria 3.2](bria_3_2) | text2image |
|
||||
| [CogVideoX](cogvideox) | text2video |
|
||||
|
||||
@@ -895,9 +895,8 @@ class TokenEmbeddingsHandler:
|
||||
self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)
|
||||
|
||||
# random initialization of new tokens
|
||||
embeds = (
|
||||
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
|
||||
)
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
|
||||
std_token_embedding = embeds.weight.data.std()
|
||||
|
||||
logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
|
||||
@@ -905,9 +904,7 @@ class TokenEmbeddingsHandler:
|
||||
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
|
||||
# if initializer_concept are not provided, token embeddings are initialized randomly
|
||||
if args.initializer_concept is None:
|
||||
hidden_size = (
|
||||
text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
|
||||
)
|
||||
hidden_size = text_module.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
|
||||
embeds.weight.data[train_ids] = (
|
||||
torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
@@ -940,7 +937,8 @@ class TokenEmbeddingsHandler:
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
|
||||
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
|
||||
new_token_embeddings = embeds.weight.data[train_ids]
|
||||
|
||||
@@ -962,7 +960,8 @@ class TokenEmbeddingsHandler:
|
||||
@torch.no_grad()
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
|
||||
embeds.weight.data[index_no_updates] = (
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
|
||||
@@ -2112,7 +2111,8 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
|
||||
text_encoder_one.train()
|
||||
if args.enable_t5_ti:
|
||||
|
||||
@@ -763,19 +763,28 @@ class TokenEmbeddingsHandler:
|
||||
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
|
||||
|
||||
# random initialization of new tokens
|
||||
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
|
||||
std_token_embedding = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.std()
|
||||
|
||||
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
|
||||
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||
torch.randn(
|
||||
len(self.train_ids),
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).config.hidden_size,
|
||||
)
|
||||
.to(device=self.device)
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
@@ -794,10 +803,14 @@ class TokenEmbeddingsHandler:
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
self.tokenizers[0]
|
||||
), "Tokenizers should be the same."
|
||||
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
|
||||
assert (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
|
||||
"Tokenizers should be the same."
|
||||
)
|
||||
new_token_embeddings = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[self.train_ids]
|
||||
|
||||
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
|
||||
# text_encoder 1) to keep compatible with the ecosystem.
|
||||
@@ -819,7 +832,9 @@ class TokenEmbeddingsHandler:
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_no_updates] = (
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
|
||||
.to(device=text_encoder.device)
|
||||
.to(dtype=text_encoder.dtype)
|
||||
@@ -830,11 +845,15 @@ class TokenEmbeddingsHandler:
|
||||
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
|
||||
|
||||
index_updates = ~index_no_updates
|
||||
new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
|
||||
new_embeddings = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_updates]
|
||||
off_ratio = std_token_embedding / new_embeddings.std()
|
||||
|
||||
new_embeddings = new_embeddings * (off_ratio**0.1)
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
@@ -1704,7 +1723,8 @@ def main(args):
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.text_model.embeddings.requires_grad_(True)
|
||||
_te_one = text_encoder_one
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
||||
@@ -929,19 +929,28 @@ class TokenEmbeddingsHandler:
|
||||
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
|
||||
|
||||
# random initialization of new tokens
|
||||
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
|
||||
std_token_embedding = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.std()
|
||||
|
||||
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
|
||||
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||
torch.randn(
|
||||
len(self.train_ids),
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).config.hidden_size,
|
||||
)
|
||||
.to(device=self.device)
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
@@ -959,10 +968,14 @@ class TokenEmbeddingsHandler:
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
self.tokenizers[0]
|
||||
), "Tokenizers should be the same."
|
||||
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
|
||||
assert (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
|
||||
"Tokenizers should be the same."
|
||||
)
|
||||
new_token_embeddings = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[self.train_ids]
|
||||
|
||||
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
|
||||
# text_encoder 1) to keep compatible with the ecosystem.
|
||||
@@ -984,7 +997,9 @@ class TokenEmbeddingsHandler:
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_no_updates] = (
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
|
||||
.to(device=text_encoder.device)
|
||||
.to(dtype=text_encoder.dtype)
|
||||
@@ -995,11 +1010,15 @@ class TokenEmbeddingsHandler:
|
||||
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
|
||||
|
||||
index_updates = ~index_no_updates
|
||||
new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
|
||||
new_embeddings = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_updates]
|
||||
off_ratio = std_token_embedding / new_embeddings.std()
|
||||
|
||||
new_embeddings = new_embeddings * (off_ratio**0.1)
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
@@ -2083,8 +2102,10 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if args.train_text_encoder:
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = accelerator.unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
_te_two = accelerator.unwrap_model(text_encoder_two)
|
||||
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if pivoted:
|
||||
|
||||
@@ -874,10 +874,11 @@ def main(args):
|
||||
token_embeds[x] = token_embeds[y]
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
text_module.encoder.parameters(),
|
||||
text_module.final_layer_norm.parameters(),
|
||||
text_module.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
########################################################
|
||||
|
||||
@@ -1691,7 +1691,8 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
|
||||
@@ -1740,9 +1740,12 @@ def main(args):
|
||||
prompt_embeds = prompt_embeds_cache[step]
|
||||
text_ids = text_ids_cache[step]
|
||||
else:
|
||||
num_repeat_elements = len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
|
||||
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
|
||||
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
|
||||
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
|
||||
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
|
||||
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
|
||||
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
|
||||
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
@@ -1809,10 +1812,11 @@ def main(args):
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = torch.mean(
|
||||
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
target_prior.shape[0], -1
|
||||
),
|
||||
1,
|
||||
|
||||
@@ -1680,9 +1680,12 @@ def main(args):
|
||||
prompt_embeds = prompt_embeds_cache[step]
|
||||
text_ids = text_ids_cache[step]
|
||||
else:
|
||||
num_repeat_elements = len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
|
||||
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
|
||||
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
|
||||
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
|
||||
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
|
||||
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
|
||||
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
|
||||
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
@@ -1752,10 +1755,11 @@ def main(args):
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = torch.mean(
|
||||
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
target_prior.shape[0], -1
|
||||
),
|
||||
1,
|
||||
|
||||
@@ -1896,7 +1896,8 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
|
||||
@@ -1719,8 +1719,10 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = accelerator.unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
_te_two = accelerator.unwrap_model(text_encoder_two)
|
||||
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
|
||||
@@ -1661,8 +1661,10 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = accelerator.unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
_te_two = accelerator.unwrap_model(text_encoder_two)
|
||||
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
|
||||
@@ -45,7 +45,16 @@ def annotate_pipeline(pipe):
|
||||
method = getattr(component, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
setattr(component, method_name, annotate(method, label))
|
||||
|
||||
# Apply fix ONLY for LTX2 pipelines
|
||||
if "LTX2" in pipe.__class__.__name__:
|
||||
func = getattr(method, "__func__", method)
|
||||
wrapped = annotate(func, label)
|
||||
bound_method = wrapped.__get__(component, type(component))
|
||||
setattr(component, method_name, bound_method)
|
||||
else:
|
||||
# keep original behavior for other pipelines
|
||||
setattr(component, method_name, annotate(method, label))
|
||||
|
||||
# Annotate pipeline-level methods
|
||||
if hasattr(pipe, "encode_prompt"):
|
||||
|
||||
@@ -702,9 +702,10 @@ def main():
|
||||
vae.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.text_model.encoder.requires_grad_(False)
|
||||
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
text_module.encoder.requires_grad_(False)
|
||||
text_module.final_layer_norm.requires_grad_(False)
|
||||
text_module.embeddings.position_embedding.requires_grad_(False)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
# Keep unet in train mode if we are using gradient checkpointing to save memory.
|
||||
|
||||
@@ -717,12 +717,14 @@ def main():
|
||||
unet.requires_grad_(False)
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder_1.text_model.encoder.requires_grad_(False)
|
||||
text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
text_encoder_2.text_model.encoder.requires_grad_(False)
|
||||
text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
text_module_1 = text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
|
||||
text_module_1.encoder.requires_grad_(False)
|
||||
text_module_1.final_layer_norm.requires_grad_(False)
|
||||
text_module_1.embeddings.position_embedding.requires_grad_(False)
|
||||
text_module_2 = text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
|
||||
text_module_2.encoder.requires_grad_(False)
|
||||
text_module_2.final_layer_norm.requires_grad_(False)
|
||||
text_module_2.embeddings.position_embedding.requires_grad_(False)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder_1.gradient_checkpointing_enable()
|
||||
@@ -767,8 +769,12 @@ def main():
|
||||
optimizer = optimizer_class(
|
||||
# only optimize the embeddings
|
||||
[
|
||||
text_encoder_1.text_model.embeddings.token_embedding.weight,
|
||||
text_encoder_2.text_model.embeddings.token_embedding.weight,
|
||||
(
|
||||
text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
|
||||
).embeddings.token_embedding.weight,
|
||||
(
|
||||
text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
|
||||
).embeddings.token_embedding.weight,
|
||||
],
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
|
||||
224
scripts/convert_longcat_audio_dit_to_diffusers.py
Normal file
224
scripts/convert_longcat_audio_dit_to_diffusers.py
Normal file
@@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Usage:
|
||||
# python scripts/convert_longcat_audio_dit_to_diffusers.py --checkpoint_path /path/to/model --output_path /data/models
|
||||
# python scripts/convert_longcat_audio_dit_to_diffusers.py --repo_id meituan-longcat/LongCat-AudioDiT-1B --output_path /data/models
|
||||
# python scripts/convert_longcat_audio_dit_to_diffusers.py --checkpoint_path /path/to/model --output_path /data/models --dtype fp16
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoTokenizer, UMT5Config, UMT5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LongCatAudioDiTPipeline,
|
||||
LongCatAudioDiTTransformer,
|
||||
LongCatAudioDiTVae,
|
||||
)
|
||||
|
||||
|
||||
def find_checkpoint(input_dir: Path):
|
||||
safetensors_file = input_dir / "model.safetensors"
|
||||
if safetensors_file.exists():
|
||||
return input_dir, safetensors_file
|
||||
|
||||
index_file = input_dir / "model.safetensors.index.json"
|
||||
if index_file.exists():
|
||||
with open(index_file) as f:
|
||||
index = json.load(f)
|
||||
weight_map = index.get("weight_map", {})
|
||||
first_weight = list(weight_map.values())[0]
|
||||
return input_dir, input_dir / first_weight
|
||||
|
||||
for subdir in input_dir.iterdir():
|
||||
if subdir.is_dir():
|
||||
safetensors_file = subdir / "model.safetensors"
|
||||
if safetensors_file.exists():
|
||||
return subdir, safetensors_file
|
||||
index_file = subdir / "model.safetensors.index.json"
|
||||
if index_file.exists():
|
||||
with open(index_file) as f:
|
||||
index = json.load(f)
|
||||
weight_map = index.get("weight_map", {})
|
||||
first_weight = list(weight_map.values())[0]
|
||||
return subdir, subdir / first_weight
|
||||
|
||||
raise FileNotFoundError(f"No checkpoint found in {input_dir}")
|
||||
|
||||
|
||||
def convert_longcat_audio_dit(
    checkpoint_path: str | None = None,
    repo_id: str | None = None,
    output_path: str = "",
    dtype: str = "fp32",
    text_encoder_model: str = "google/umt5-xxl",
):
    """Convert a flat LongCat-AudioDiT checkpoint into diffusers pipeline layout.

    Exactly one of ``checkpoint_path`` (local directory) or ``repo_id``
    (Hugging Face Hub repo) must be given. The converted pipeline is written
    to ``<output_path>/<model_name>-Diffusers``.

    Args:
        checkpoint_path: Local directory containing ``config.json`` and ``model.safetensors``.
        repo_id: Hugging Face Hub repo id to download instead of a local path.
        output_path: Parent directory for the converted pipeline.
        dtype: One of ``"fp32"``, ``"fp16"``, ``"bf16"`` for the converted weights.
        text_encoder_model: Hub id used only to fetch the tokenizer.

    Raises:
        ValueError: If neither or both of ``checkpoint_path`` / ``repo_id`` are given.
        FileNotFoundError: If the checkpoint or its ``config.json`` cannot be found.
        RuntimeError: If the text encoder weights do not match the expected layout.
    """
    if not checkpoint_path and not repo_id:
        raise ValueError("Either --checkpoint_path or --repo_id must be provided")
    if checkpoint_path and repo_id:
        raise ValueError("Cannot specify both --checkpoint_path and --repo_id")

    dtype_map = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    torch_dtype = dtype_map.get(dtype, torch.float32)

    if repo_id:
        input_dir = Path(snapshot_download(repo_id, local_files_only=False))
        model_name = repo_id.split("/")[-1]
    else:
        input_dir = Path(checkpoint_path)
        if not input_dir.exists():
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
        model_name = None

    # Keep the resolved weights file under its own name rather than shadowing
    # the `checkpoint_path` argument.
    model_dir, weights_path = find_checkpoint(input_dir)
    if model_name is None:
        model_name = model_dir.name

    config_path = model_dir / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"config.json not found in {model_dir}")

    with open(config_path) as f:
        config = json.load(f)

    state_dict = load_file(weights_path)

    def _sub_state_dict(prefix: str) -> dict:
        # Select the keys belonging to one component and strip the component
        # prefix (len(prefix) instead of hard-coded slice offsets).
        return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}

    transformer_state_dict = _sub_state_dict("transformer.")
    vae_state_dict = _sub_state_dict("vae.")
    text_encoder_state_dict = _sub_state_dict("text_encoder.")

    # Transformer hyperparameters come from the flat config; `config.get`
    # defaults mirror the model's own defaults for older checkpoints.
    transformer = LongCatAudioDiTTransformer(
        dit_dim=config["dit_dim"],
        dit_depth=config["dit_depth"],
        dit_heads=config["dit_heads"],
        dit_text_dim=config["dit_text_dim"],
        latent_dim=config["latent_dim"],
        dropout=config.get("dit_dropout", 0.0),
        bias=config.get("dit_bias", True),
        cross_attn=config.get("dit_cross_attn", True),
        adaln_type=config.get("dit_adaln_type", "global"),
        adaln_use_text_cond=config.get("dit_adaln_use_text_cond", True),
        long_skip=config.get("dit_long_skip", True),
        text_conv=config.get("dit_text_conv", True),
        qk_norm=config.get("dit_qk_norm", True),
        cross_attn_norm=config.get("dit_cross_attn_norm", False),
        eps=config.get("dit_eps", 1e-6),
        use_latent_condition=config.get("dit_use_latent_condition", True),
    )
    transformer.load_state_dict(transformer_state_dict, strict=True)
    transformer = transformer.to(dtype=torch_dtype)

    vae_config = dict(config["vae_config"])
    # "model_type" is checkpoint metadata, not a constructor argument.
    vae_config.pop("model_type", None)
    vae = LongCatAudioDiTVae(**vae_config)
    vae.load_state_dict(vae_state_dict, strict=True)
    vae = vae.to(dtype=torch_dtype)

    text_encoder_config = UMT5Config.from_dict(config["text_encoder_config"])
    text_encoder = UMT5EncoderModel(text_encoder_config)
    text_missing, text_unexpected = text_encoder.load_state_dict(text_encoder_state_dict, strict=False)

    # "shared.weight" is reconstructed from encoder.embed_tokens below, so it
    # is the only tolerated missing key.
    allowed_missing = {"shared.weight"}
    unexpected_missing = set(text_missing) - allowed_missing
    if unexpected_missing:
        raise RuntimeError(f"Unexpected missing text encoder weights: {sorted(unexpected_missing)}")
    if text_unexpected:
        raise RuntimeError(f"Unexpected text encoder weights: {sorted(text_unexpected)}")
    if "shared.weight" in text_missing:
        text_encoder.shared.weight.data.copy_(text_encoder.encoder.embed_tokens.weight.data)

    text_encoder = text_encoder.to(dtype=torch_dtype)

    tokenizer = AutoTokenizer.from_pretrained(text_encoder_model)

    # Checkpoint-provided scheduler options override these defaults.
    scheduler_config = {"shift": 1.0, "invert_sigmas": True}
    scheduler_config.update(config.get("scheduler_config", {}))
    scheduler = FlowMatchEulerDiscreteScheduler(**scheduler_config)

    pipeline = LongCatAudioDiTPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Extra pipeline attributes consumed at inference time.
    pipeline.sample_rate = config.get("sampling_rate", 24000)
    pipeline.vae_scale_factor = config.get("vae_scale_factor", config.get("latent_hop", 2048))
    pipeline.max_wav_duration = config.get("max_wav_duration", 30.0)
    pipeline.text_norm_feat = config.get("text_norm_feat", True)
    pipeline.text_add_embed = config.get("text_add_embed", True)

    output_dir = Path(output_path) / f"{model_name}-Diffusers"
    output_dir.mkdir(parents=True, exist_ok=True)

    pipeline.save_pretrained(output_dir)
|
||||
|
||||
|
||||
def get_args():
    """Parse the conversion script's command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local model directory")
    parser.add_argument("--repo_id", type=str, default=None, help="HuggingFace repo_id to download model")
    parser.add_argument("--output_path", type=str, required=True, help="Output directory")
    parser.add_argument(
        "--dtype", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Data type for converted weights"
    )
    parser.add_argument(
        "--text_encoder_model",
        type=str,
        default="google/umt5-xxl",
        help="HuggingFace model ID for text encoder tokenizer",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse the arguments and run the conversion end to end.
    args = get_args()
    convert_longcat_audio_dit(
        checkpoint_path=args.checkpoint_path,
        repo_id=args.repo_id,
        output_path=args.output_path,
        dtype=args.dtype,
        text_encoder_model=args.text_encoder_model,
    )
|
||||
@@ -254,6 +254,8 @@ else:
|
||||
"Kandinsky3UNet",
|
||||
"Kandinsky5Transformer3DModel",
|
||||
"LatteTransformer3DModel",
|
||||
"LongCatAudioDiTTransformer",
|
||||
"LongCatAudioDiTVae",
|
||||
"LongCatImageTransformer2DModel",
|
||||
"LTX2VideoTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
@@ -599,6 +601,7 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LLaDA2Pipeline",
|
||||
"LLaDA2PipelineOutput",
|
||||
"LongCatAudioDiTPipeline",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ConditionPipeline",
|
||||
@@ -1058,6 +1061,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky3UNet,
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatAudioDiTTransformer,
|
||||
LongCatAudioDiTVae,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
@@ -1378,6 +1383,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LLaDA2Pipeline,
|
||||
LLaDA2PipelineOutput,
|
||||
LongCatAudioDiTPipeline,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ConditionPipeline,
|
||||
|
||||
@@ -50,6 +50,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_longcat_audio_dit"] = ["LongCatAudioDiTVae"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
_import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
@@ -112,6 +113,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"]
|
||||
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
|
||||
@@ -180,6 +182,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderTiny,
|
||||
AutoencoderVidTok,
|
||||
ConsistencyDecoderVAE,
|
||||
LongCatAudioDiTVae,
|
||||
VQModel,
|
||||
)
|
||||
from .cache_utils import CacheMixin
|
||||
@@ -233,6 +236,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanVideoTransformer3DModel,
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatAudioDiTTransformer,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
|
||||
@@ -19,6 +19,7 @@ from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_longcat_audio_dit import LongCatAudioDiTVae
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
from .autoencoder_rae import AutoencoderRAE
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
@@ -180,7 +180,7 @@ class QwenImageResample(nn.Module):
|
||||
feat_cache[idx] = "Rep"
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat(
|
||||
@@ -258,7 +258,7 @@ class QwenImageResidualBlock(nn.Module):
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
@@ -277,7 +277,7 @@ class QwenImageResidualBlock(nn.Module):
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
@@ -446,7 +446,7 @@ class QwenImageEncoder3d(nn.Module):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
@@ -471,7 +471,7 @@ class QwenImageEncoder3d(nn.Module):
|
||||
x = self.nonlinearity(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
@@ -636,7 +636,7 @@ class QwenImageDecoder3d(nn.Module):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
@@ -658,7 +658,7 @@ class QwenImageDecoder3d(nn.Module):
|
||||
x = self.nonlinearity(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Adapted from the LongCat-AudioDiT reference implementation:
|
||||
# https://github.com/meituan-longcat/LongCat-AudioDiT
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin
|
||||
|
||||
|
||||
def _wn_conv1d(in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, bias=True):
|
||||
return weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias))
|
||||
|
||||
|
||||
def _wn_conv_transpose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
    """Snake activation ``x + sin^2(alpha * x) / beta`` with per-channel learnable parameters.

    With ``alpha_logscale=True`` the parameters are stored in log space and
    exponentiated on the fly, so the zero initialization corresponds to
    ``alpha = beta = 1``.
    """

    def __init__(self, channels: int, alpha_logscale: bool = True):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(torch.zeros(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Broadcast the per-channel parameters over (batch, channels, time).
        alpha = self.alpha.view(1, -1, 1)
        beta = self.beta.view(1, -1, 1)
        if self.alpha_logscale:
            alpha = alpha.exp()
            beta = beta.exp()
        # Epsilon guards against division by a vanishing beta.
        return hidden_states + torch.sin(hidden_states * alpha).pow(2) / (beta + 1e-9)
|
||||
|
||||
|
||||
def _get_vae_activation(name: str, channels: int = 0) -> nn.Module:
|
||||
if name == "elu":
|
||||
act = nn.ELU()
|
||||
elif name == "snake":
|
||||
act = Snake1d(channels)
|
||||
else:
|
||||
raise ValueError(f"Unknown activation: {name}")
|
||||
return act
|
||||
|
||||
|
||||
def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
|
||||
batch, channels, width = hidden_states.size()
|
||||
return (
|
||||
hidden_states.view(batch, channels // factor, factor, width)
|
||||
.permute(0, 1, 3, 2)
|
||||
.contiguous()
|
||||
.view(batch, channels // factor, width * factor)
|
||||
)
|
||||
|
||||
|
||||
class DownsampleShortcut(nn.Module):
    """Parameter-free downsampling residual path.

    Folds ``factor`` consecutive time steps into the channel axis (inverse of
    a 1D pixel shuffle), then averages channel groups so the output has
    ``out_channels`` channels at ``1/factor`` the temporal resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, factor: int):
        super().__init__()
        self.factor = factor
        # Number of folded channels averaged into each output channel.
        self.group_size = in_channels * factor // out_channels
        self.out_channels = out_channels

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, channels, width = hidden_states.shape
        new_width = width // self.factor
        # (B, C, W) -> (B, C * factor, W / factor)
        folded = hidden_states.view(batch, channels, new_width, self.factor)
        folded = folded.permute(0, 1, 3, 2).contiguous()
        folded = folded.view(batch, channels * self.factor, new_width)
        # Average groups of channels down to `out_channels`.
        grouped = folded.view(batch, self.out_channels, self.group_size, new_width)
        return grouped.mean(dim=2)
|
||||
|
||||
|
||||
class UpsampleShortcut(nn.Module):
    """Parameter-free upsampling residual path.

    Repeats input channels so that a 1D pixel shuffle by ``factor`` yields
    ``out_channels`` channels at ``factor`` times the temporal resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, factor: int):
        super().__init__()
        self.factor = factor
        self.repeats = out_channels * factor // in_channels

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = hidden_states.repeat_interleave(self.repeats, dim=1)
        # Inlined 1D pixel shuffle: (B, C, W) -> (B, C // factor, W * factor).
        batch, channels, width = expanded.size()
        shuffled = expanded.view(batch, channels // self.factor, self.factor, width)
        shuffled = shuffled.permute(0, 1, 3, 2).contiguous()
        return shuffled.view(batch, channels // self.factor, width * self.factor)
|
||||
|
||||
|
||||
class VaeResidualUnit(nn.Module):
    """Dilated residual unit: act -> dilated conv -> act -> 1x1 conv, plus a skip connection."""

    def __init__(
        self, in_channels: int, out_channels: int, dilation: int, kernel_size: int = 7, act_fn: str = "snake"
    ):
        super().__init__()
        # "Same" padding for the dilated convolution (odd kernel sizes).
        same_padding = (dilation * (kernel_size - 1)) // 2
        # NOTE(review): the first activation is sized with out_channels but
        # receives in_channels features; safe only while callers keep
        # in_channels == out_channels (as the encoder/decoder blocks do).
        self.layers = nn.Sequential(
            _get_vae_activation(act_fn, channels=out_channels),
            _wn_conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=same_padding),
            _get_vae_activation(act_fn, channels=out_channels),
            _wn_conv1d(out_channels, out_channels, kernel_size=1),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        return residual + self.layers(hidden_states)
|
||||
|
||||
|
||||
class VaeEncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units, then a strided downsampling conv.

    With ``downsample_shortcut="averaging"`` a parameter-free residual path
    (`DownsampleShortcut`) spans the whole stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        act_fn: str = "snake",
        downsample_shortcut: str = "none",
    ):
        super().__init__()
        modules = [
            VaeResidualUnit(in_channels, in_channels, dilation=dilation, act_fn=act_fn) for dilation in (1, 3, 9)
        ]
        modules.append(_get_vae_activation(act_fn, channels=in_channels))
        # The strided conv performs the actual temporal downsampling.
        modules.append(
            _wn_conv1d(in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
        )
        self.layers = nn.Sequential(*modules)
        if downsample_shortcut == "averaging":
            self.residual = DownsampleShortcut(in_channels, out_channels, stride)
        else:
            self.residual = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        outputs = self.layers(hidden_states)
        if self.residual is not None:
            outputs = outputs + self.residual(hidden_states)
        return outputs
|
||||
|
||||
|
||||
class VaeDecoderBlock(nn.Module):
    """Decoder stage: a transposed upsampling conv followed by three dilated residual units.

    With ``upsample_shortcut="duplicating"`` a parameter-free residual path
    (`UpsampleShortcut`) spans the whole stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        act_fn: str = "snake",
        upsample_shortcut: str = "none",
    ):
        super().__init__()
        modules = [
            _get_vae_activation(act_fn, channels=in_channels),
            # The transposed conv performs the actual temporal upsampling.
            _wn_conv_transpose1d(
                in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
            ),
        ]
        modules.extend(
            VaeResidualUnit(out_channels, out_channels, dilation=dilation, act_fn=act_fn) for dilation in (1, 3, 9)
        )
        self.layers = nn.Sequential(*modules)
        if upsample_shortcut == "duplicating":
            self.residual = UpsampleShortcut(in_channels, out_channels, stride)
        else:
            self.residual = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        outputs = self.layers(hidden_states)
        if self.residual is not None:
            outputs = outputs + self.residual(hidden_states)
        return outputs
|
||||
|
||||
|
||||
class AudioDiTVaeEncoder(nn.Module):
    """Waveform encoder: a conv stem, a stack of strided `VaeEncoderBlock`s, and a
    projection to ``encoder_latent_dim`` channels (optionally with an averaging shortcut)."""

    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 128,
        c_mults: list[int] | None = None,
        strides: list[int] | None = None,
        latent_dim: int = 64,
        encoder_latent_dim: int = 128,
        act_fn: str = "snake",
        downsample_shortcut: str = "averaging",
        out_shortcut: str = "averaging",
    ):
        super().__init__()
        # NOTE(review): `latent_dim` is accepted but never used by the
        # encoder — kept, presumably for signature symmetry with the decoder.
        # Prepend a unit multiplier so index 0 describes the stem stage.
        c_mults = [1] + (c_mults or [1, 2, 4, 8, 16])
        strides = list(strides or [2] * (len(c_mults) - 1))
        # Pad (by repeating the last stride) or truncate so there is exactly
        # one stride per downsampling stage.
        if len(strides) < len(c_mults) - 1:
            strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides)))
        else:
            strides = strides[: len(c_mults) - 1]
        channels_base = channels
        layers = [_wn_conv1d(in_channels, c_mults[0] * channels_base, kernel_size=7, padding=3)]
        for idx in range(len(c_mults) - 1):
            layers.append(
                VaeEncoderBlock(
                    c_mults[idx] * channels_base,
                    c_mults[idx + 1] * channels_base,
                    strides[idx],
                    act_fn=act_fn,
                    downsample_shortcut=downsample_shortcut,
                )
            )
        # Final projection to the encoder latent dimension.
        layers.append(_wn_conv1d(c_mults[-1] * channels_base, encoder_latent_dim, kernel_size=3, padding=1))
        self.layers = nn.Sequential(*layers)
        # Optional parameter-free shortcut around the output projection
        # (factor 1: channel averaging only, no temporal change).
        self.shortcut = (
            DownsampleShortcut(c_mults[-1] * channels_base, encoder_latent_dim, 1)
            if out_shortcut == "averaging"
            else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Run everything except the output projection, so its input can also
        # feed the shortcut.
        hidden_states = self.layers[:-1](hidden_states)
        output_hidden_states = self.layers[-1](hidden_states)
        if self.shortcut is not None:
            shortcut = self.shortcut(hidden_states)
            output_hidden_states = output_hidden_states + shortcut
        return output_hidden_states
|
||||
|
||||
|
||||
class AudioDiTVaeDecoder(nn.Module):
    """Mirror of `AudioDiTVaeEncoder`: upsamples latents back to a waveform.

    The channel multipliers and strides are consumed in reverse order relative
    to the encoder, so the two form a symmetric hourglass.
    """

    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 128,
        c_mults: list[int] | None = None,
        strides: list[int] | None = None,
        latent_dim: int = 64,
        act_fn: str = "snake",
        in_shortcut: str = "duplicating",
        final_tanh: bool = False,
        upsample_shortcut: str = "duplicating",
    ):
        super().__init__()
        # Same multiplier/stride normalization scheme as the encoder.
        c_mults = [1] + (c_mults or [1, 2, 4, 8, 16])
        strides = list(strides or [2] * (len(c_mults) - 1))
        if len(strides) < len(c_mults) - 1:
            strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides)))
        else:
            strides = strides[: len(c_mults) - 1]
        channels_base = channels

        # Optional parameter-free path from the latents to the widest feature
        # map (factor 1: channel duplication only, no temporal change).
        self.shortcut = (
            UpsampleShortcut(latent_dim, c_mults[-1] * channels_base, 1) if in_shortcut == "duplicating" else None
        )

        layers = [_wn_conv1d(latent_dim, c_mults[-1] * channels_base, kernel_size=7, padding=3)]
        # Walk the multipliers from widest to narrowest stage.
        for idx in range(len(c_mults) - 1, 0, -1):
            layers.append(
                VaeDecoderBlock(
                    c_mults[idx] * channels_base,
                    c_mults[idx - 1] * channels_base,
                    strides[idx - 1],
                    act_fn=act_fn,
                    upsample_shortcut=upsample_shortcut,
                )
            )
        layers.append(_get_vae_activation(act_fn, channels=c_mults[0] * channels_base))
        layers.append(_wn_conv1d(c_mults[0] * channels_base, in_channels, kernel_size=7, padding=3, bias=False))
        # Optionally clamp the output waveform to [-1, 1].
        layers.append(nn.Tanh() if final_tanh else nn.Identity())
        self.layers = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.shortcut is None:
            return self.layers(hidden_states)
        # Add the shortcut to the stem conv's output, then run the remaining stages.
        hidden_states = self.shortcut(hidden_states) + self.layers[0](hidden_states)
        return self.layers[1:](hidden_states)
|
||||
|
||||
|
||||
@dataclass
class LongCatAudioDiTVaeEncoderOutput(BaseOutput):
    """Output of [`LongCatAudioDiTVae.encode`].

    Args:
        latents (`torch.Tensor`):
            Encoded audio latents (posterior sample or mean, already divided by the config `scale`).
    """

    latents: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class LongCatAudioDiTVaeDecoderOutput(BaseOutput):
    """Output of [`LongCatAudioDiTVae.decode`].

    Args:
        sample (`torch.Tensor`):
            Decoded waveform.
    """

    sample: torch.Tensor
|
||||
|
||||
|
||||
class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin):
    """1D convolutional audio VAE used by the LongCat-AudioDiT pipeline.

    Encodes waveforms of shape ``(batch, in_channels, samples)`` into latents
    and decodes latents back to waveforms. The posterior is parameterized as a
    mean plus a softplus-derived standard deviation; latents are divided by
    the config ``scale`` on encode and multiplied back on decode.
    """

    # Group offloading is explicitly disabled for this model.
    _supports_group_offloading = False

    @register_to_config
    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 128,
        c_mults: list[int] | None = None,
        strides: list[int] | None = None,
        latent_dim: int = 64,
        encoder_latent_dim: int = 128,
        act_fn: str | None = None,
        use_snake: bool | None = None,
        downsample_shortcut: str = "averaging",
        upsample_shortcut: str = "duplicating",
        out_shortcut: str = "averaging",
        in_shortcut: str = "duplicating",
        final_tanh: bool = False,
        downsampling_ratio: int = 2048,
        sample_rate: int = 24000,
        scale: float = 0.71,
    ):
        super().__init__()
        # `use_snake` is a boolean fallback switch; an explicit `act_fn`
        # ("snake"/"elu") takes precedence, and "snake" is the default.
        if act_fn is None:
            if use_snake is None:
                act_fn = "snake"
            else:
                act_fn = "snake" if use_snake else "elu"
        self.encoder = AudioDiTVaeEncoder(
            in_channels=in_channels,
            channels=channels,
            c_mults=c_mults,
            strides=strides,
            latent_dim=latent_dim,
            encoder_latent_dim=encoder_latent_dim,
            act_fn=act_fn,
            downsample_shortcut=downsample_shortcut,
            out_shortcut=out_shortcut,
        )
        self.decoder = AudioDiTVaeDecoder(
            in_channels=in_channels,
            channels=channels,
            c_mults=c_mults,
            strides=strides,
            latent_dim=latent_dim,
            act_fn=act_fn,
            in_shortcut=in_shortcut,
            final_tanh=final_tanh,
            upsample_shortcut=upsample_shortcut,
        )

    @apply_forward_hook
    def encode(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = True,
        return_dict: bool = True,
        generator: torch.Generator | None = None,
    ) -> LongCatAudioDiTVaeEncoderOutput | tuple[torch.Tensor]:
        """Encode a waveform into scaled latents.

        Args:
            sample: Input waveform tensor; cast to the encoder's dtype if needed.
            sample_posterior: If True, sample latents from the posterior;
                otherwise return the posterior mean.
            return_dict: Return a `LongCatAudioDiTVaeEncoderOutput` instead of a tuple.
            generator: Optional RNG used for posterior sampling.
        """
        encoder_dtype = next(self.encoder.parameters()).dtype
        if sample.dtype != encoder_dtype:
            sample = sample.to(encoder_dtype)
        encoded = self.encoder(sample)
        # The encoder output packs mean and scale parameters along channels.
        mean, scale_param = encoded.chunk(2, dim=1)
        # softplus keeps the std positive; epsilon avoids a degenerate posterior.
        std = F.softplus(scale_param) + 1e-4
        if sample_posterior:
            noise = randn_tensor(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype)
            latents = mean + std * noise
        else:
            latents = mean
        # Normalize latents by the configured scale for the diffusion model.
        latents = latents / self.config.scale
        # Latents are always returned in float32, regardless of model dtype.
        if encoder_dtype != torch.float32:
            latents = latents.float()
        if not return_dict:
            return (latents,)
        return LongCatAudioDiTVaeEncoderOutput(latents=latents)

    @apply_forward_hook
    def decode(
        self, latents: torch.Tensor, return_dict: bool = True
    ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]:
        """Decode scaled latents back into a waveform."""
        decoder_dtype = next(self.decoder.parameters()).dtype
        # Undo the normalization applied in `encode`.
        latents = latents * self.config.scale
        if latents.dtype != decoder_dtype:
            latents = latents.to(decoder_dtype)
        decoded = self.decoder(latents)
        # Waveforms are always returned in float32.
        if decoder_dtype != torch.float32:
            decoded = decoded.float()
        if not return_dict:
            return (decoded,)
        return LongCatAudioDiTVaeDecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: torch.Generator | None = None,
    ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]:
        """Full autoencode round-trip: encode `sample`, then decode the latents."""
        latents = self.encode(sample, sample_posterior=sample_posterior, return_dict=True, generator=generator).latents
        decoded = self.decode(latents, return_dict=True).sample
        if not return_dict:
            return (decoded,)
        return LongCatAudioDiTVaeDecoderOutput(sample=decoded)
|
||||
@@ -36,6 +36,7 @@ if is_torch_available():
|
||||
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
|
||||
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
|
||||
from .transformer_kandinsky import Kandinsky5Transformer3DModel
|
||||
from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer
|
||||
from .transformer_longcat_image import LongCatImageTransformer2DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_ltx2 import LTX2VideoTransformer3DModel
|
||||
|
||||
@@ -0,0 +1,605 @@
|
||||
# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Adapted from the LongCat-AudioDiT reference implementation:
|
||||
# https://github.com/meituan-longcat/LongCat-AudioDiT
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
@dataclass
class LongCatAudioDiTTransformerOutput(BaseOutput):
    """Output of [`LongCatAudioDiTTransformer`].

    Args:
        sample (`torch.Tensor`):
            The transformer's predicted output.
    """

    sample: torch.Tensor
|
||||
|
||||
|
||||
class AudioDiTSinusPositionEmbedding(nn.Module):
    """Sinusoidal timestep embedding: sine features in the first half, cosines in the second."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, timesteps: torch.Tensor, scale: float = 1000.0) -> torch.Tensor:
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to ~1/10000 over half_dim bins.
        decay = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(-decay * torch.arange(half_dim, device=timesteps.device).float())
        # Outer product of (scaled) timesteps with the frequency ladder.
        angles = scale * timesteps.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
||||
|
||||
|
||||
class AudioDiTTimestepEmbedding(nn.Module):
    """Maps scalar diffusion timesteps to a `dim`-sized conditioning vector.

    Sinusoidal features of size `freq_embed_dim` are projected through a
    two-layer MLP with a SiLU nonlinearity.
    """

    def __init__(self, dim: int, freq_embed_dim: int = 256):
        super().__init__()
        self.time_embed = AudioDiTSinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        hidden_states = self.time_embed(timestep)
        # Cast the sinusoidal features back to the timestep's dtype before the
        # MLP. NOTE(review): this assumes `timestep` is floating point — an
        # integer timestep tensor would truncate the features here; confirm
        # callers always pass floats.
        return self.time_mlp(hidden_states.to(timestep.dtype))
|
||||
|
||||
|
||||
class AudioDiTRotaryEmbedding(nn.Module):
    """Rotary position embedding tables for the audio DiT attention layers.

    Produces `(cos, sin)` tables of shape `(seq_len, dim)` in which the
    half-dim frequency pattern is duplicated across both halves of the last
    axis, matching the `_rotate_half` / `_apply_rotary_emb` convention.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 100000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

    # Results are cached per (self, seq_len, device) with a bounded maxsize.
    # NOTE(review): caching on an instance method keeps the module alive for
    # the cache's lifetime — presumably acceptable since maxsize is bounded.
    @lru_cache_unless_export(maxsize=128)
    def _build(self, seq_len: int, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Materialize cos/sin tables for `seq_len` positions on `device`."""
        # Standard RoPE inverse-frequency ladder: base^(-2i/dim).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        if device is not None:
            inv_freq = inv_freq.to(device)
        steps = torch.arange(seq_len, dtype=torch.int64, device=inv_freq.device).type_as(inv_freq)
        freqs = torch.outer(steps, inv_freq)
        # Duplicate the half-dim frequencies so the tables span the full dim.
        embeddings = torch.cat((freqs, freqs), dim=-1)
        return embeddings.cos().contiguous(), embeddings.sin().contiguous()

    def forward(self, hidden_states: torch.Tensor, seq_len: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Return `(cos, sin)` for the first `seq_len` positions, cast to `hidden_states`' dtype."""
        seq_len = hidden_states.shape[1] if seq_len is None else seq_len
        # Build at least `max_position_embeddings` rows so shorter sequences
        # reuse the same cache entry; slice down afterwards.
        cos, sin = self._build(max(seq_len, self.max_position_embeddings), hidden_states.device)
        return cos[:seq_len].to(dtype=hidden_states.dtype), sin[:seq_len].to(dtype=hidden_states.dtype)
|
||||
|
||||
|
||||
def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
first, second = hidden_states.chunk(2, dim=-1)
|
||||
return torch.cat((-second, first), dim=-1)
|
||||
|
||||
|
||||
def _apply_rotary_emb(hidden_states: torch.Tensor, rope: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Apply rotary position embedding to a `(batch, seq, heads, head_dim)` tensor."""
    cos, sin = rope
    # (seq, head_dim) -> (1, seq, 1, head_dim) to broadcast over batch and heads.
    cos = cos[None, :, None].to(hidden_states.device)
    sin = sin[None, :, None].to(hidden_states.device)
    # Inlined half-rotation: swap halves of the last axis, negating the second.
    first, second = hidden_states.chunk(2, dim=-1)
    rotated = torch.cat((-second, first), dim=-1)
    # Accumulate in float32 for stability, then cast back to the input dtype.
    out = hidden_states.float() * cos + rotated.float() * sin
    return out.to(hidden_states.dtype)
|
||||
|
||||
|
||||
class AudioDiTGRN(nn.Module):
    """Global Response Normalization (ConvNeXt V2 style) with a residual path.

    With the default zero-initialized `gamma` and `beta` this is an identity map.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # L2 norm over the sequence axis, normalized by its channel-wise mean.
        seq_norm = torch.norm(hidden_states, p=2, dim=1, keepdim=True)
        response = seq_norm / (seq_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return hidden_states + self.gamma * (hidden_states * response) + self.beta
|
||||
|
||||
|
||||
class AudioDiTConvNeXtV2Block(nn.Module):
    """ConvNeXt-V2-style residual block over `(batch, seq, dim)` sequences.

    Pipeline: depthwise temporal conv -> LayerNorm -> pointwise expansion ->
    SiLU -> GRN -> pointwise contraction, added back to the input.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
        kernel_size: int = 7,
        bias: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        # "Same" padding for the dilated depthwise convolution.
        padding = (dilation * (kernel_size - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=kernel_size, padding=padding, groups=dim, dilation=dilation, bias=bias
        )
        self.norm = nn.LayerNorm(dim, eps=eps)
        self.pwconv1 = nn.Linear(dim, intermediate_dim, bias=bias)
        self.act = nn.SiLU()
        self.grn = AudioDiTGRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim, bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        skip = hidden_states
        # Conv1d wants channels-first, so move `dim` to axis 1 and back again.
        features = self.dwconv(hidden_states.transpose(1, 2)).transpose(1, 2)
        features = self.pwconv1(self.norm(features))
        features = self.grn(self.act(features))
        return skip + self.pwconv2(features)
|
||||
|
||||
|
||||
class AudioDiTEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Sequential(nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None) -> torch.Tensor:
|
||||
if mask is not None:
|
||||
hidden_states = hidden_states.masked_fill(mask.logical_not().unsqueeze(-1), 0.0)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
if mask is not None:
|
||||
hidden_states = hidden_states.masked_fill(mask.logical_not().unsqueeze(-1), 0.0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AudioDiTAdaLNMLP(nn.Module):
    """SiLU followed by a linear projection; produces AdaLN modulation vectors."""

    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
        super().__init__()
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(in_dim, out_dim, bias=bias))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        modulation = self.mlp(hidden_states)
        return modulation
|
||||
|
||||
|
||||
class AudioDiTAdaLayerNormZeroFinal(nn.Module):
    """Final AdaLN layer: affine-free LayerNorm modulated by a conditioning vector.

    The conditioning embedding is mapped through SiLU + Linear into `(scale, shift)`
    which modulate the normalized hidden states as `h * (1 + scale) + shift`.
    """

    def __init__(self, dim: int, bias: bool = True, eps: float = 1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2, bias=bias)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

    def forward(self, hidden_states: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
        modulation = self.linear(self.silu(embedding))
        scale, shift = modulation.chunk(2, dim=-1)
        if scale.ndim == 2:
            # Per-sample modulation: broadcast over the sequence axis.
            scale = scale[:, None, :]
            shift = shift[:, None, :]
        # Normalize in float32, then restore the input dtype.
        normed = self.norm(hidden_states.float()).type_as(hidden_states)
        return normed * (1 + scale) + shift
|
||||
|
||||
|
||||
class AudioDiTSelfAttnProcessor:
    """Self-attention processor for `AudioDiTAttention` (no encoder states).

    Applies optional QK RMSNorm, rotary embeddings to both query and key, then
    dispatches to the configured attention backend.
    """

    _attention_backend = None
    _parallel_config = None

    def __call__(
        self,
        attn: "AudioDiTAttention",
        hidden_states: torch.Tensor,
        attention_mask: torch.BoolTensor | None = None,
        audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # QK norm acts on the packed (inner_dim) projections, before the
        # tensors are split into heads.
        if attn.qk_norm:
            query = attn.q_norm(query)
            key = attn.k_norm(key)

        head_dim = attn.inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        if audio_rotary_emb is not None:
            query = _apply_rotary_emb(query, audio_rotary_emb)
            key = _apply_rotary_emb(key, audio_rotary_emb)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Zero out padded positions after attention; the broadcast implies a
        # 2-D (batch, seq) mask.
        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask[:, :, None, None].to(hidden_states.dtype)

        # (batch, seq, heads, head_dim) -> (batch, seq, inner_dim), then output
        # projection followed by dropout.
        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
|
||||
|
||||
|
||||
class AudioDiTAttention(nn.Module, AttentionModuleMixin):
    """Attention module owning the QKV/output projections and optional QK RMSNorm.

    The actual attention computation is delegated to the processor installed via
    `set_processor` (self-attention by default, cross-attention when a
    `AudioDiTCrossAttnProcessor` is supplied).
    """

    def __init__(
        self,
        q_dim: int,
        kv_dim: int | None,
        heads: int,
        dim_head: int,
        dropout: float = 0.0,
        bias: bool = True,
        qk_norm: bool = False,
        eps: float = 1e-6,
        processor: AttentionModuleMixin | None = None,
    ):
        super().__init__()
        # Self-attention configurations pass kv_dim=None and reuse q_dim.
        if kv_dim is None:
            kv_dim = q_dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.to_q = nn.Linear(q_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=bias)
        self.qk_norm = qk_norm
        if qk_norm:
            self.q_norm = RMSNorm(self.inner_dim, eps=eps)
            self.k_norm = RMSNorm(self.inner_dim, eps=eps)
        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)])
        self.set_processor(processor if processor is not None else AudioDiTSelfAttnProcessor())

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        post_attention_mask: torch.BoolTensor | None = None,
        attention_mask: torch.BoolTensor | None = None,
        audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        # The self-attention processor has a narrower signature, so the
        # cross-attention arguments are only forwarded when encoder states are
        # actually provided.
        processor_kwargs = {"attention_mask": attention_mask, "audio_rotary_emb": audio_rotary_emb}
        if encoder_hidden_states is not None:
            processor_kwargs["encoder_hidden_states"] = encoder_hidden_states
            processor_kwargs["post_attention_mask"] = post_attention_mask
            processor_kwargs["prompt_rotary_emb"] = prompt_rotary_emb
        return self.processor(self, hidden_states, **processor_kwargs)
|
||||
|
||||
|
||||
class AudioDiTCrossAttnProcessor:
    """Cross-attention processor: queries come from the audio latents, keys and
    values from the text condition.

    `attention_mask` is forwarded to the attention kernel as the key mask while
    `post_attention_mask` zeroes padded audio positions after attention.
    """

    _attention_backend = None
    _parallel_config = None

    def __call__(
        self,
        attn: "AudioDiTAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        post_attention_mask: torch.BoolTensor | None = None,
        attention_mask: torch.BoolTensor | None = None,
        audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # QK norm acts on the packed (inner_dim) projections, before splitting
        # into heads.
        if attn.qk_norm:
            query = attn.q_norm(query)
            key = attn.k_norm(key)

        head_dim = attn.inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        # Separate rotary tables: audio positions for queries, prompt positions
        # for keys.
        if audio_rotary_emb is not None:
            query = _apply_rotary_emb(query, audio_rotary_emb)
        if prompt_rotary_emb is not None:
            key = _apply_rotary_emb(key, prompt_rotary_emb)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Zero out padded audio positions; the broadcast implies a 2-D
        # (batch, seq) mask.
        if post_attention_mask is not None:
            hidden_states = hidden_states * post_attention_mask[:, :, None, None].to(hidden_states.dtype)

        # (batch, seq, heads, head_dim) -> (batch, seq, inner_dim), then output
        # projection followed by dropout.
        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
|
||||
|
||||
|
||||
class AudioDiTFeedForward(nn.Module):
    """Position-wise MLP with tanh-approximated GELU and dropout."""

    def __init__(self, dim: int, mult: float = 4.0, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        hidden_dim = int(dim * mult)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=bias),
            nn.GELU(approximate="tanh"),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim, bias=bias),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.ff(hidden_states)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
class AudioDiTBlock(nn.Module):
    """One AudioDiT transformer block: AdaLN-modulated self-attention, optional
    cross-attention to the text condition, and an AdaLN-modulated feed-forward.

    `adaln_type`:
        - "local": each block owns an `adaln_mlp` that maps the timestep
          embedding (plus the mean text embedding when `adaln_use_text_cond`)
          to its six modulation vectors.
        - "global": the transformer computes one shared modulation and each
          block adds its learned `adaln_scale_shift` offset.
    """

    def __init__(
        self,
        dim: int,
        cond_dim: int,
        heads: int,
        dim_head: int,
        dropout: float = 0.0,
        bias: bool = True,
        qk_norm: bool = False,
        eps: float = 1e-6,
        cross_attn: bool = True,
        cross_attn_norm: bool = False,
        adaln_type: str = "global",
        adaln_use_text_cond: bool = True,
        ff_mult: float = 4.0,
    ):
        super().__init__()
        self.adaln_type = adaln_type
        self.adaln_use_text_cond = adaln_use_text_cond
        if adaln_type == "local":
            # Per-block modulation MLP producing 6 vectors (gate/scale/shift x2).
            self.adaln_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True)
        elif adaln_type == "global":
            # Per-block learned offset added to the shared global modulation.
            self.adaln_scale_shift = nn.Parameter(torch.randn(dim * 6) / dim**0.5)

        self.self_attn = AudioDiTAttention(
            dim, None, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps
        )

        self.use_cross_attn = cross_attn
        if cross_attn:
            self.cross_attn = AudioDiTAttention(
                dim,
                cond_dim,
                heads,
                dim_head,
                dropout=dropout,
                bias=bias,
                qk_norm=qk_norm,
                eps=eps,
                processor=AudioDiTCrossAttnProcessor(),
            )
            self.cross_attn_norm = (
                nn.LayerNorm(dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity()
            )
            self.cross_attn_norm_c = (
                nn.LayerNorm(cond_dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity()
            )
        self.ffn = AudioDiTFeedForward(dim=dim, mult=ff_mult, dropout=dropout, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep_embed: torch.Tensor,
        cond: torch.Tensor,
        mask: torch.BoolTensor | None = None,
        cond_mask: torch.BoolTensor | None = None,
        rope: tuple | None = None,
        cond_rope: tuple | None = None,
        adaln_global_out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if self.adaln_type == "local" and adaln_global_out is None:
            if self.adaln_use_text_cond:
                # Mean-pool the text condition over valid tokens.
                # NOTE(review): this path assumes `cond_mask` is not None — confirm callers.
                denom = cond_mask.sum(1, keepdim=True).clamp(min=1).to(cond.dtype)
                cond_mean = cond.sum(1) / denom
                norm_cond = timestep_embed + cond_mean
            else:
                norm_cond = timestep_embed
            adaln_out = self.adaln_mlp(norm_cond)
            gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1)
        else:
            # Global AdaLN: shared modulation plus this block's learned offset.
            # NOTE(review): reads `self.adaln_scale_shift`, which only exists when
            # adaln_type == "global" — confirm this branch is unreachable otherwise.
            adaln_out = adaln_global_out + self.adaln_scale_shift.unsqueeze(0)
            gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1)

        # Pre-norm (affine-free, computed in float32) + AdaLN modulation.
        norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(
            hidden_states
        )
        norm_hidden_states = norm_hidden_states * (1 + scale_sa[:, None]) + shift_sa[:, None]
        attn_output = self.self_attn(
            norm_hidden_states,
            attention_mask=mask,
            audio_rotary_emb=rope,
        )
        # Gated residual connection.
        hidden_states = hidden_states + gate_sa.unsqueeze(1) * attn_output

        if self.use_cross_attn:
            # Cross-attention is an ungated residual branch; `mask` zeroes
            # padded audio frames, `cond_mask` masks text keys.
            cross_output = self.cross_attn(
                hidden_states=self.cross_attn_norm(hidden_states),
                encoder_hidden_states=self.cross_attn_norm_c(cond),
                post_attention_mask=mask,
                attention_mask=cond_mask,
                audio_rotary_emb=rope,
                prompt_rotary_emb=cond_rope,
            )
            hidden_states = hidden_states + cross_output

        # Second pre-norm + modulation feeding the feed-forward branch.
        norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(
            hidden_states
        )
        norm_hidden_states = norm_hidden_states * (1 + scale_ffn[:, None]) + shift_ffn[:, None]
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = hidden_states + gate_ffn.unsqueeze(1) * ff_output
        return hidden_states
|
||||
|
||||
|
||||
class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin):
    """Text-conditioned diffusion transformer over audio VAE latents.

    Latents of shape `(batch, seq, latent_dim)` are embedded to `dit_dim`,
    processed by `dit_depth` AdaLN transformer blocks with cross-attention to
    UMT5 text features, and projected back to `latent_dim`.
    """

    _supports_gradient_checkpointing = False
    _repeated_blocks = ["AudioDiTBlock"]

    @register_to_config
    def __init__(
        self,
        dit_dim: int = 1536,
        dit_depth: int = 24,
        dit_heads: int = 24,
        dit_text_dim: int = 768,
        latent_dim: int = 64,
        dropout: float = 0.0,
        bias: bool = True,
        cross_attn: bool = True,
        adaln_type: str = "global",
        adaln_use_text_cond: bool = True,
        long_skip: bool = True,
        text_conv: bool = True,
        qk_norm: bool = True,
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        use_latent_condition: bool = True,
    ):
        super().__init__()
        dim = dit_dim
        dim_head = dim // dit_heads
        self.time_embed = AudioDiTTimestepEmbedding(dim)
        self.input_embed = AudioDiTEmbedder(latent_dim, dim)
        self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
        # RoPE tables for the per-head dimension, cached up to 2048 positions.
        self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
        self.blocks = nn.ModuleList(
            [
                AudioDiTBlock(
                    dim=dim,
                    cond_dim=dim,
                    heads=dit_heads,
                    dim_head=dim_head,
                    dropout=dropout,
                    bias=bias,
                    qk_norm=qk_norm,
                    eps=eps,
                    cross_attn=cross_attn,
                    cross_attn_norm=cross_attn_norm,
                    adaln_type=adaln_type,
                    adaln_use_text_cond=adaln_use_text_cond,
                    ff_mult=4.0,
                )
                for _ in range(dit_depth)
            ]
        )
        self.norm_out = AudioDiTAdaLayerNormZeroFinal(dim, bias=bias, eps=eps)
        self.proj_out = nn.Linear(dim, latent_dim)
        if adaln_type == "global":
            # Shared modulation MLP; each block adds its own learned offset.
            self.adaln_global_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True)
        self.text_conv = text_conv
        if text_conv:
            # Refine text features with a small ConvNeXt-V2 stack.
            self.text_conv_layer = nn.Sequential(
                *[AudioDiTConvNeXtV2Block(dim, dim * 2, bias=bias, eps=eps) for _ in range(4)]
            )
        self.use_latent_condition = use_latent_condition
        if use_latent_condition:
            # Optional extra latent conditioning, fused with the input embedding.
            self.latent_embed = AudioDiTEmbedder(latent_dim, dim)
            self.latent_cond_embedder = AudioDiTEmbedder(dim * 2, dim)
        self._initialize_weights(bias=bias)

    def _initialize_weights(self, bias: bool = True):
        # Zero-init modulation and output projections so the AdaLN/output paths
        # start as (near-)identity.
        if self.config.adaln_type == "local":
            for block in self.blocks:
                nn.init.constant_(block.adaln_mlp.mlp[-1].weight, 0)
                if bias:
                    nn.init.constant_(block.adaln_mlp.mlp[-1].bias, 0)
        elif self.config.adaln_type == "global":
            nn.init.constant_(self.adaln_global_mlp.mlp[-1].weight, 0)
            if bias:
                nn.init.constant_(self.adaln_global_mlp.mlp[-1].bias, 0)
        nn.init.constant_(self.norm_out.linear.weight, 0)
        nn.init.constant_(self.proj_out.weight, 0)
        if bias:
            nn.init.constant_(self.norm_out.linear.bias, 0)
            nn.init.constant_(self.proj_out.bias, 0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.BoolTensor,
        timestep: torch.Tensor,
        attention_mask: torch.BoolTensor | None = None,
        latent_cond: torch.Tensor | None = None,
        return_dict: bool = True,
    ) -> LongCatAudioDiTTransformerOutput | tuple[torch.Tensor]:
        """Denoise `hidden_states` conditioned on text features and a timestep.

        Args:
            hidden_states: Noisy latents `(batch, seq, latent_dim)`.
            encoder_hidden_states: Text features `(batch, text_seq, dit_text_dim)`.
            encoder_attention_mask: Valid-token mask for the text features.
            timestep: Scalar or `(batch,)` diffusion timestep(s).
            attention_mask: Optional valid-frame mask for the audio latents.
            latent_cond: Optional conditioning latents (same layout as
                `hidden_states`); used only when `use_latent_condition`.
            return_dict: Return a `LongCatAudioDiTTransformerOutput` when True,
                else a 1-tuple.
        """
        dtype = hidden_states.dtype
        encoder_hidden_states = encoder_hidden_states.to(dtype)
        timestep = timestep.to(dtype)
        batch_size = hidden_states.shape[0]
        # Broadcast a scalar timestep to the batch.
        if timestep.ndim == 0:
            timestep = timestep.repeat(batch_size)
        timestep_embed = self.time_embed(timestep)
        text_mask = encoder_attention_mask.bool()
        encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
        if self.text_conv:
            encoder_hidden_states = self.text_conv_layer(encoder_hidden_states)
            # The conv stack mixes across positions, so re-zero padded tokens.
            encoder_hidden_states = encoder_hidden_states.masked_fill(text_mask.logical_not().unsqueeze(-1), 0.0)
        hidden_states = self.input_embed(hidden_states, attention_mask)
        if self.use_latent_condition and latent_cond is not None:
            latent_cond = self.latent_embed(latent_cond.to(hidden_states.dtype), attention_mask)
            hidden_states = self.latent_cond_embedder(torch.cat([hidden_states, latent_cond], dim=-1))
        # Long skip: the post-embedding activation is added back after the blocks.
        residual = hidden_states.clone() if self.config.long_skip else None
        rope = self.rotary_embed(hidden_states, hidden_states.shape[1])
        cond_rope = self.rotary_embed(encoder_hidden_states, encoder_hidden_states.shape[1])
        if self.config.adaln_type == "global":
            if self.config.adaln_use_text_cond:
                # Mean-pool text features over valid tokens for the modulation input.
                text_len = text_mask.sum(1).clamp(min=1).to(encoder_hidden_states.dtype)
                text_mean = encoder_hidden_states.sum(1) / text_len.unsqueeze(1)
                norm_cond = timestep_embed + text_mean
            else:
                norm_cond = timestep_embed
            adaln_global_out = self.adaln_global_mlp(norm_cond)
            for block in self.blocks:
                hidden_states = block(
                    hidden_states=hidden_states,
                    timestep_embed=timestep_embed,
                    cond=encoder_hidden_states,
                    mask=attention_mask,
                    cond_mask=text_mask,
                    rope=rope,
                    cond_rope=cond_rope,
                    adaln_global_out=adaln_global_out,
                )
        else:
            # Local AdaLN: each block computes its own modulation.
            # NOTE(review): the final `norm_out` then conditions on the timestep
            # only, even when `adaln_use_text_cond` is set — confirm intended.
            norm_cond = timestep_embed
            for block in self.blocks:
                hidden_states = block(
                    hidden_states=hidden_states,
                    timestep_embed=timestep_embed,
                    cond=encoder_hidden_states,
                    mask=attention_mask,
                    cond_mask=text_mask,
                    rope=rope,
                    cond_rope=cond_rope,
                )
        if self.config.long_skip:
            hidden_states = hidden_states + residual
        hidden_states = self.norm_out(hidden_states, norm_cond)
        hidden_states = self.proj_out(hidden_states)
        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        if not return_dict:
            return (hidden_states,)
        return LongCatAudioDiTTransformerOutput(sample=hidden_states)
|
||||
@@ -326,6 +326,7 @@ else:
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
_import_structure["longcat_image"] = ["LongCatImagePipeline", "LongCatImageEditPipeline"]
|
||||
_import_structure["longcat_audio_dit"] = ["LongCatAudioDiTPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
@@ -753,6 +754,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
|
||||
from .longcat_audio_dit import LongCatAudioDiTPipeline
|
||||
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
|
||||
from .ltx import (
|
||||
LTXConditionPipeline,
|
||||
|
||||
40
src/diffusers/pipelines/longcat_audio_dit/__init__.py
Normal file
40
src/diffusers/pipelines/longcat_audio_dit/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Standard diffusers lazy-import boilerplate: the pipeline is only importable
# when both torch and transformers are installed; otherwise dummy objects that
# raise on use are exposed instead.
_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa: F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_longcat_audio_dit"] = ["LongCatAudioDiTPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import directly (type checkers and DIFFUSERS_SLOW_IMPORT=1).
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_longcat_audio_dit import LongCatAudioDiTPipeline
else:
    import sys

    # Lazy path: defer heavy imports until an attribute is actually accessed.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,332 @@
|
||||
# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Adapted from the LongCat-AudioDiT reference implementation:
|
||||
# https://github.com/meituan-longcat/LongCat-AudioDiT
|
||||
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import PreTrainedTokenizerBase, UMT5EncoderModel
|
||||
|
||||
from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _lens_to_mask(lengths: torch.Tensor, length: int | None = None) -> torch.BoolTensor:
|
||||
if length is None:
|
||||
length = int(lengths.amax().item())
|
||||
seq = torch.arange(length, device=lengths.device)
|
||||
return seq[None, :] < lengths[:, None]
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
|
||||
text = text.lower()
|
||||
text = re.sub(r'["“”‘’]', " ", text)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
|
||||
if not text:
|
||||
return 0.0
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
en_dur_per_char = 0.082
|
||||
zh_dur_per_char = 0.21
|
||||
durations = []
|
||||
for prompt in text:
|
||||
prompt = re.sub(r"\s+", "", prompt)
|
||||
num_zh = num_en = num_other = 0
|
||||
for char in prompt:
|
||||
if "一" <= char <= "鿿":
|
||||
num_zh += 1
|
||||
elif char.isalpha():
|
||||
num_en += 1
|
||||
else:
|
||||
num_other += 1
|
||||
if num_zh > num_en:
|
||||
num_zh += num_other
|
||||
else:
|
||||
num_en += num_other
|
||||
durations.append(num_zh * zh_dur_per_char + num_en * en_dur_per_char)
|
||||
return min(max_duration, max(durations)) if durations else 0.0
|
||||
|
||||
|
||||
class LongCatAudioDiTPipeline(DiffusionPipeline):
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
    def __init__(
        self,
        vae: LongCatAudioDiTVae,
        text_encoder: UMT5EncoderModel,
        tokenizer: PreTrainedTokenizerBase,
        transformer: LongCatAudioDiTTransformer,
        scheduler: FlowMatchEulerDiscreteScheduler | None = None,
    ):
        """Register the pipeline components and cache frequently used config values.

        NOTE(review): any `scheduler` that is not a `FlowMatchEulerDiscreteScheduler`
        (including `None`) is silently replaced with a default instance — confirm
        this substitution is intended rather than raising.
        """
        super().__init__()
        if not isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
            scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Fall back to defaults when the component configs omit these fields.
        self.sample_rate = getattr(vae.config, "sample_rate", 24000)
        self.vae_scale_factor = getattr(vae.config, "downsampling_ratio", 2048)
        self.latent_dim = getattr(transformer.config, "latent_dim", 64)
        # Hard cap on generated audio length, in seconds.
        self.max_wav_duration = 30.0
        # Text-feature post-processing flags used by `encode_prompt`.
        self.text_norm_feat = True
        self.text_add_embed = True
|
||||
|
||||
    @property
    def guidance_scale(self):
        # Classifier-free guidance scale recorded for the current call
        # (presumably set by `__call__`; exposed for step-end callbacks).
        return self._guidance_scale
|
||||
|
||||
    @property
    def num_timesteps(self):
        # Number of denoising steps recorded for the current call
        # (presumably set by `__call__`; exposed for step-end callbacks).
        return self._num_timesteps
|
||||
|
||||
    def encode_prompt(self, prompt: str | list[str], device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize and encode `prompt` with the UMT5 text encoder.

        Returns:
            `(prompt_embeds, lengths)`: `prompt_embeds` is the (optionally
            layer-normalized) last hidden state, optionally summed with the
            equally-normalized first hidden state; `lengths` holds the number of
            non-padding tokens per prompt.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        # Guard against tokenizers reporting a missing or sentinel "huge"
        # model_max_length; clamp to a sane default.
        model_max_length = getattr(self.tokenizer, "model_max_length", 512)
        if not isinstance(model_max_length, int) or model_max_length <= 0 or model_max_length > 32768:
            model_max_length = 512
        text_inputs = self.tokenizer(
            prompt,
            padding="longest",
            truncation=True,
            max_length=model_max_length,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)
        with torch.no_grad():
            output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        prompt_embeds = output.last_hidden_state
        if self.text_norm_feat:
            # Parameter-free layer norm over the feature dimension.
            prompt_embeds = F.layer_norm(prompt_embeds, (prompt_embeds.shape[-1],), eps=1e-6)
        if self.text_add_embed and getattr(output, "hidden_states", None):
            # hidden_states[0] — presumably the input token embeddings; added as
            # a low-level feature stream.
            first_hidden = output.hidden_states[0]
            if self.text_norm_feat:
                first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6)
            prompt_embeds = prompt_embeds + first_hidden
        lengths = attention_mask.sum(dim=1).to(device)
        return prompt_embeds, lengths
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
duration: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
if latents.ndim != 3:
|
||||
raise ValueError(
|
||||
f"`latents` must have shape (batch_size, duration, latent_dim), but got {tuple(latents.shape)}."
|
||||
)
|
||||
if latents.shape[0] != batch_size:
|
||||
raise ValueError(f"`latents` must have batch size {batch_size}, but got {latents.shape[0]}.")
|
||||
if latents.shape[2] != self.latent_dim:
|
||||
raise ValueError(f"`latents` must have latent_dim {self.latent_dim}, but got {latents.shape[2]}.")
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}."
|
||||
)
|
||||
|
||||
return randn_tensor((batch_size, duration, self.latent_dim), generator=generator, device=device, dtype=dtype)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt: list[str],
|
||||
negative_prompt: str | list[str] | None,
|
||||
output_type: str,
|
||||
callback_on_step_end_tensor_inputs: list[str] | None = None,
|
||||
) -> None:
|
||||
if len(prompt) == 0:
|
||||
raise ValueError("`prompt` must contain at least one prompt.")
|
||||
|
||||
if output_type not in {"np", "pt", "latent"}:
|
||||
raise ValueError(f"Unsupported output_type: {output_type}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
|
||||
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if negative_prompt is not None and not isinstance(negative_prompt, str):
|
||||
negative_prompt = list(negative_prompt)
|
||||
if len(negative_prompt) != len(prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt` must have batch size {len(prompt)}, but got {len(negative_prompt)} prompts."
|
||||
)
|
||||
|
||||
@torch.no_grad()
def __call__(
    self,
    prompt: str | list[str],
    negative_prompt: str | list[str] | None = None,
    audio_duration_s: float | None = None,
    latents: torch.Tensor | None = None,
    num_inference_steps: int = 16,
    guidance_scale: float = 4.0,
    generator: torch.Generator | list[torch.Generator] | None = None,
    output_type: str = "np",
    return_dict: bool = True,
    callback_on_step_end: Callable[[int, int], None] | None = None,
    # NOTE: mutable default is only ever read (membership test / iteration), never mutated,
    # matching the convention used by other diffusers pipelines.
    callback_on_step_end_tensor_inputs: list[str] = ["latents"],
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `list[str]`): Prompt or prompts that guide audio generation.
        negative_prompt (`str` or `list[str]`, *optional*): Negative prompt(s) for classifier-free guidance.
        audio_duration_s (`float`, *optional*):
            Target audio duration in seconds. Ignored when `latents` is provided.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents of shape `(batch_size, duration, latent_dim)`.
        num_inference_steps (`int`, defaults to 16): Number of denoising steps.
        guidance_scale (`float`, defaults to 4.0): Guidance scale for classifier-free guidance.
        generator (`torch.Generator` or `list[torch.Generator]`, *optional*): Random generator(s).
        output_type (`str`, defaults to `"np"`): Output format: `"np"`, `"pt"`, or `"latent"`.
        return_dict (`bool`, defaults to `True`): Whether to return `AudioPipelineOutput`.
        callback_on_step_end (`Callable`, *optional*):
            A function called at the end of each denoising step with the pipeline, step index, timestep, and tensor
            inputs specified by `callback_on_step_end_tensor_inputs`. It must return a dict; `latents` and
            `prompt_embeds` entries, if present, replace the running values.
        callback_on_step_end_tensor_inputs (`list`, defaults to `["latents"]`):
            Tensor inputs passed to `callback_on_step_end`.

    Returns:
        `AudioPipelineOutput` when `return_dict` is `True`, otherwise a 1-tuple with the waveform
        (or latents when `output_type == "latent"`).
    """
    # --- Normalize `prompt` to a list; an empty list is rejected by `check_inputs`. ---
    if prompt is None:
        prompt = []
    elif isinstance(prompt, str):
        prompt = [prompt]
    else:
        prompt = list(prompt)
    self.check_inputs(prompt, negative_prompt, output_type, callback_on_step_end_tensor_inputs)
    batch_size = len(prompt)
    self._guidance_scale = guidance_scale

    device = self._execution_device
    normalized_prompts = [_normalize_text(text) for text in prompt]
    # --- Resolve the latent-frame count (`duration`). Precedence: user latents > explicit
    # seconds > a heuristic estimate from the prompt text. The seconds-to-frames conversion
    # assumes `vae_scale_factor` is the number of audio samples per latent frame — TODO confirm.
    if latents is not None:
        duration = latents.shape[1]
    elif audio_duration_s is not None:
        duration = int(audio_duration_s * self.sample_rate // self.vae_scale_factor)
    else:
        duration = int(_approx_duration_from_text(normalized_prompts) * self.sample_rate // self.vae_scale_factor)
    max_duration = int(self.max_wav_duration * self.sample_rate // self.vae_scale_factor)
    # Clamp only when we generated the latents ourselves; user-provided latents keep their length.
    if latents is None:
        duration = max(1, min(duration, max_duration))

    # --- Encode prompts and build boolean validity masks for audio frames and text tokens. ---
    prompt_embeds, prompt_embeds_len = self.encode_prompt(normalized_prompts, device)
    duration_tensor = torch.full((batch_size,), duration, device=device, dtype=torch.long)
    mask = _lens_to_mask(duration_tensor)
    text_mask = _lens_to_mask(prompt_embeds_len, length=prompt_embeds.shape[1])

    # --- Unconditional branch for classifier-free guidance: zero embeddings when no
    # negative prompt is given, otherwise encode the user's negative prompt(s). ---
    if negative_prompt is None:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_prompt_embeds_len = prompt_embeds_len
        negative_prompt_embeds_mask = text_mask
    else:
        # A single string is broadcast across the whole batch.
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        else:
            negative_prompt = list(negative_prompt)
        negative_prompt_embeds, negative_prompt_embeds_len = self.encode_prompt(negative_prompt, device)
        negative_prompt_embeds_mask = _lens_to_mask(
            negative_prompt_embeds_len, length=negative_prompt_embeds.shape[1]
        )

    # All-zero conditioning latents fed to the transformer alongside the noisy latents.
    latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=prompt_embeds.dtype)
    latents = self.prepare_latents(
        batch_size, duration, device, prompt_embeds.dtype, generator=generator, latents=latents
    )
    if num_inference_steps < 1:
        raise ValueError("num_inference_steps must be a positive integer.")

    # --- Uniform flow-matching sigma grid from 1.0 down to 1/num_inference_steps. ---
    sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps, dtype=torch.float32).tolist()
    self.scheduler.set_timesteps(sigmas=sigmas, device=device)
    self.scheduler.set_begin_index(0)
    timesteps = self.scheduler.timesteps
    self._num_timesteps = len(timesteps)

    # --- Denoising loop. ---
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Transformer expects a per-sample timestep normalized to [0, 1].
            curr_t = (
                (t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=prompt_embeds.dtype)
            )
            pred = self.transformer(
                hidden_states=latents,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=text_mask,
                timestep=curr_t,
                attention_mask=mask,
                latent_cond=latent_cond,
            ).sample
            if self.guidance_scale > 1.0:
                # Second (unconditional) pass; combined via the standard CFG formula below.
                null_pred = self.transformer(
                    hidden_states=latents,
                    encoder_hidden_states=negative_prompt_embeds,
                    encoder_attention_mask=negative_prompt_embeds_mask,
                    timestep=curr_t,
                    attention_mask=mask,
                    latent_cond=latent_cond,
                ).sample
                pred = null_pred + (pred - null_pred) * self.guidance_scale
            latents = self.scheduler.step(pred, t, latents, return_dict=False)[0]
            progress_bar.update()

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # `locals()[k]` requires every requested name to be bound in this scope;
                # `check_inputs` already restricted `k` to `self._callback_tensor_inputs`.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

    # --- Decode (unless latents were requested) and shape the output. ---
    if output_type == "latent":
        waveform = latents
    else:
        # The VAE decoder takes channels-first input, hence (B, T, C) -> (B, C, T).
        waveform = self.vae.decode(latents.permute(0, 2, 1)).sample
        if output_type == "np":
            waveform = waveform.cpu().float().numpy()

    self.maybe_free_model_hooks()

    if not return_dict:
        return (waveform,)
    return AudioPipelineOutput(audios=waveform)
|
||||
@@ -1395,6 +1395,36 @@ class LatteTransformer3DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LongCatAudioDiTTransformer(metaclass=DummyObject):
    # Import-time stand-in used when the `torch` backend is unavailable; every entry
    # point defers to `requires_backends` for the backend check instead of running.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LongCatAudioDiTVae(metaclass=DummyObject):
    # Import-time stand-in used when the `torch` backend is unavailable; every entry
    # point defers to `requires_backends` for the backend check instead of running.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LongCatImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2297,6 +2297,21 @@ class LLaDA2PipelineOutput(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LongCatAudioDiTPipeline(metaclass=DummyObject):
    # Import-time stand-in used when `torch` and `transformers` are unavailable; every
    # entry point defers to `requires_backends` for the backend check instead of running.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LongCatImageEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import LongCatAudioDiTTransformer
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LongCatAudioDiTTransformerTesterConfig(BaseModelTesterConfig):
    """Shared tiny-model configuration consumed by the LongCat-AudioDiT tester mixins."""

    @property
    def main_input_name(self) -> str:
        # Keyword the mixins use when feeding the primary (noisy-latent) input.
        return "hidden_states"

    @property
    def model_class(self):
        return LongCatAudioDiTTransformer

    @property
    def output_shape(self) -> tuple[int, ...]:
        # (sequence_length, latent_dim) of a single denoised sample.
        return (16, 8)

    @property
    def generator(self):
        # A fresh, deterministically seeded CPU generator on every access.
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | bool | float | str]:
        init_kwargs = {
            "dit_dim": 64,
            "dit_depth": 2,
            "dit_heads": 4,
            "dit_text_dim": 32,
            "latent_dim": 8,
            "text_conv": False,
        }
        return init_kwargs

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        """Build a minimal, deterministic forward-pass input set on `torch_device`."""
        num_samples = 1
        audio_len = 16
        text_len = 10
        latent_channels = 8
        text_channels = 32

        # Each `randn_tensor` call reads `self.generator` separately so both tensors
        # are drawn from an independently re-seeded generator (seed 0).
        noise = randn_tensor(
            (num_samples, audio_len, latent_channels), generator=self.generator, device=torch_device
        )
        text_states = randn_tensor(
            (num_samples, text_len, text_channels), generator=self.generator, device=torch_device
        )
        return {
            "hidden_states": noise,
            "encoder_hidden_states": text_states,
            "encoder_attention_mask": torch.ones(num_samples, text_len, dtype=torch.bool, device=torch_device),
            "attention_mask": torch.ones(num_samples, audio_len, dtype=torch.bool, device=torch_device),
            "timestep": torch.ones(num_samples, device=torch_device),
        }
|
||||
|
||||
|
||||
class TestLongCatAudioDiTTransformer(LongCatAudioDiTTransformerTesterConfig, ModelTesterMixin):
    # Runs the generic ModelTesterMixin suite against the tiny LongCat-AudioDiT config.
    pass
|
||||
|
||||
|
||||
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
    # Memory-behavior suite; one inherited test is opted out below.
    def test_layerwise_casting_memory(self):
        # Skipped (not an expected failure): the tiny config's peak-memory numbers
        # are too noisy for the mixin's layerwise-casting assertion.
        pytest.skip(
            "LongCatAudioDiTTransformer tiny test config does not provide stable layerwise casting peak memory "
            "coverage."
        )
|
||||
|
||||
|
||||
class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
    # Runs the torch.compile compatibility suite against the tiny LongCat-AudioDiT config.
    pass
|
||||
|
||||
|
||||
class TestLongCatAudioDiTTransformerAttention(LongCatAudioDiTTransformerTesterConfig, AttentionTesterMixin):
    # Runs the attention-backend suite against the tiny LongCat-AudioDiT config.
    pass
|
||||
|
||||
|
||||
def test_longcat_audio_attention_uses_standard_self_attn_kwargs():
    """Masked positions must contribute nothing: with identity projections, a fully
    masked-out token should produce an all-zero output row."""
    from diffusers.models.transformers.transformer_longcat_audio_dit import AudioDiTAttention

    layer = AudioDiTAttention(q_dim=4, kv_dim=None, heads=1, dim_head=4, dropout=0.0, bias=False)

    # Make every projection the identity so the attention math is directly inspectable.
    identity = torch.eye(4)
    with torch.no_grad():
        for proj in (layer.to_q, layer.to_k, layer.to_v, layer.to_out[0]):
            proj.weight.copy_(identity)

    tokens = torch.tensor([[[1.0, 0.0, 0.0, 0.0], [0.5, 0.5, 0.5, 0.5]]])
    keep = torch.tensor([[True, False]])  # second position is masked out

    result = layer(hidden_states=tokens, attention_mask=keep)

    assert torch.allclose(result[:, 1], torch.zeros_like(result[:, 1]))
|
||||
0
tests/pipelines/longcat_audio_dit/__init__.py
Normal file
0
tests/pipelines/longcat_audio_dit/__init__.py
Normal file
225
tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py
Normal file
225
tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# Copyright 2026 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5Config, UMT5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LongCatAudioDiTPipeline,
|
||||
LongCatAudioDiTTransformer,
|
||||
LongCatAudioDiTVae,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
|
||||
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LongCatAudioDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast CPU tests for `LongCatAudioDiTPipeline` built on tiny randomly-initialized components."""

    pipeline_class = LongCatAudioDiTPipeline
    # Swap the generic audio params for this pipeline's signature: it uses
    # `audio_duration_s` and does not expose precomputed-embedding arguments.
    params = (
        TEXT_TO_AUDIO_PARAMS
        - {"audio_length_in_s", "prompt_embeds", "negative_prompt_embeds", "cross_attention_kwargs"}
    ) | {"audio_duration_s"}
    batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params - {"num_images_per_prompt"}
    test_attention_slicing = False
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        """Construct the smallest component set that still exercises the full pipeline graph."""
        torch.manual_seed(0)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder = UMT5EncoderModel(
            UMT5Config(d_model=32, num_layers=1, num_heads=4, d_ff=64, vocab_size=tokenizer.vocab_size)
        )
        transformer = LongCatAudioDiTTransformer(
            dit_dim=64,
            dit_depth=2,
            dit_heads=4,
            dit_text_dim=32,
            latent_dim=8,
            text_conv=False,
        )
        vae = LongCatAudioDiTVae(
            in_channels=1,
            channels=16,
            c_mults=[1, 2],
            strides=[2],
            latent_dim=8,
            encoder_latent_dim=16,
            downsampling_ratio=2,
            sample_rate=24000,
        )

        return {
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
        }

    def get_dummy_inputs(self, device, seed=0, prompt="soft ocean ambience"):
        """Minimal call kwargs for a deterministic two-step generation."""
        # MPS does not support device-local generators, so fall back to the global seed there.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        return {
            "prompt": prompt,
            "audio_duration_s": 0.1,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "generator": generator,
            "output_type": "pt",
        }

    def test_inference(self):
        # Smoke test: the pipeline runs end-to-end and returns (batch, channels, samples) audio.
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(**self.get_dummy_inputs(device)).audios

        self.assertEqual(output.ndim, 3)
        self.assertEqual(output.shape[0], 1)
        self.assertEqual(output.shape[1], 1)
        self.assertGreater(output.shape[-1], 0)

    def test_save_load_local(self):
        # Round-trip through save_pretrained/from_pretrained and run the reloaded pipeline.
        import tempfile

        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to(device)

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipe.save_pretrained(tmp_dir)
            reloaded = self.pipeline_class.from_pretrained(tmp_dir, local_files_only=True)
            output = reloaded(**self.get_dummy_inputs(device, seed=0)).audios

        self.assertIsInstance(reloaded, LongCatAudioDiTPipeline)
        self.assertEqual(output.ndim, 3)
        self.assertGreater(output.shape[-1], 0)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=2e-3)

    def test_model_cpu_offload_forward_pass(self):
        self.skipTest(
            "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test."
        )

    def test_cpu_offload_forward_pass_twice(self):
        self.skipTest(
            "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test."
        )

    def test_sequential_cpu_offload_forward_pass(self):
        self.skipTest(
            "LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with "
            "sequential offloading."
        )

    def test_sequential_offload_forward_pass_twice(self):
        self.skipTest(
            "LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with "
            "sequential offloading."
        )

    def test_pipeline_level_group_offloading_inference(self):
        self.skipTest(
            "LongCatAudioDiTPipeline group offloading coverage is not ready for the standard PipelineTesterMixin test."
        )

    def test_num_images_per_prompt(self):
        self.skipTest("LongCatAudioDiTPipeline does not support num_images_per_prompt.")

    def test_encode_prompt_works_in_isolation(self):
        self.skipTest("LongCatAudioDiTPipeline.encode_prompt has a custom signature.")

    def test_uniform_flow_match_scheduler_grid_matches_manual_updates(self):
        """The pipeline's uniform sigma grid must make the scheduler's Euler updates
        equal to a hand-rolled integration on an evenly spaced [0, 1] grid."""
        num_inference_steps = 6
        scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True)
        # Same sigma construction the pipeline uses in `__call__`.
        sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps, dtype=torch.float32).tolist()
        scheduler.set_timesteps(sigmas=sigmas, device="cpu")

        expected_grid = torch.linspace(0, 1, num_inference_steps + 1, dtype=torch.float32)
        actual_timesteps = scheduler.timesteps / scheduler.config.num_train_timesteps
        self.assertTrue(torch.allclose(actual_timesteps, expected_grid[:-1], atol=1e-6, rtol=0))

        sample = torch.zeros(1, 2, 3)
        model_output = torch.ones_like(sample)
        expected = sample.clone()
        # Manual Euler step per interval must match the scheduler's `step` output.
        for t0, t1, scheduler_t in zip(expected_grid[:-1], expected_grid[1:], scheduler.timesteps):
            expected = expected + model_output * (t1 - t0)
            sample = scheduler.step(model_output, scheduler_t, sample, return_dict=False)[0]

        self.assertTrue(torch.allclose(sample, expected, atol=1e-6, rtol=0))
|
||||
|
||||
|
||||
def test_longcat_audio_top_level_imports():
    """The LongCat-AudioDiT classes are exposed at the top-level `diffusers` namespace."""
    for exported in (LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae):
        assert exported is not None
|
||||
|
||||
|
||||
@slow
@require_torch_accelerator
class LongCatAudioDiTPipelineSlowTests(unittest.TestCase):
    """Integration test against real local LongCat-AudioDiT weights (opt-in via env vars)."""

    pipeline_class = LongCatAudioDiTPipeline

    def test_longcat_audio_pipeline_from_pretrained_real_local_weights(self):
        # Model directory comes from LONGCAT_AUDIO_DIT_MODEL_PATH (with a CI default);
        # the tokenizer path has no default and gates the whole test.
        model_path = Path(
            os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", "/data/models/meituan-longcat/LongCat-AudioDiT-1B")
        )
        tokenizer_path_env = os.getenv("LONGCAT_AUDIO_DIT_TOKENIZER_PATH")
        if tokenizer_path_env is None:
            raise unittest.SkipTest("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set")
        tokenizer_path = Path(tokenizer_path_env)

        # Skip (rather than fail) when the weights are simply not present on this machine.
        if not model_path.exists():
            raise unittest.SkipTest(f"LongCat-AudioDiT model path not found: {model_path}")
        if not tokenizer_path.exists():
            raise unittest.SkipTest(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}")

        pipe = LongCatAudioDiTPipeline.from_pretrained(
            model_path,
            tokenizer=tokenizer_path,
            torch_dtype=torch.float16,
            local_files_only=True,
        )
        pipe = pipe.to(torch_device)

        # Two steps is enough to exercise the denoising loop end to end.
        result = pipe(
            prompt="A calm ocean wave ambience with soft wind in the background.",
            audio_duration_s=2.0,
            num_inference_steps=2,
            guidance_scale=4.0,
            output_type="pt",
        )

        # Shape contract: (batch, channels, samples) with a non-empty waveform.
        assert result.audios.ndim == 3
        assert result.audios.shape[0] == 1
        assert result.audios.shape[1] == 1
        assert result.audios.shape[-1] > 0
|
||||
Reference in New Issue
Block a user