mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-16 12:47:04 +08:00
Compare commits
12 Commits
docs/model
...
flux-spmd-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
323a7dbfa8 | ||
|
|
71a6fd9f0d | ||
|
|
cae90fd696 | ||
|
|
294fd1adec | ||
|
|
a68f3677b7 | ||
|
|
b99078d227 | ||
|
|
d30831683c | ||
|
|
bf1fd9d403 | ||
|
|
c41a3c3ed8 | ||
|
|
4689714b88 | ||
|
|
0d79fc2e60 | ||
|
|
e4d219b366 |
@@ -8,7 +8,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
||||
with:
|
||||
package_name: diffusers
|
||||
secrets:
|
||||
|
||||
@@ -490,6 +490,8 @@
|
||||
- sections:
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/longcat_audio_dit
|
||||
title: LongCat-AudioDiT
|
||||
- local: api/pipelines/stable_audio
|
||||
title: Stable Audio
|
||||
title: Audio
|
||||
|
||||
61
docs/source/en/api/pipelines/longcat_audio_dit.md
Normal file
61
docs/source/en/api/pipelines/longcat_audio_dit.md
Normal file
@@ -0,0 +1,61 @@
|
||||
<!--Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LongCat-AudioDiT
|
||||
|
||||
LongCat-AudioDiT is a text-to-audio diffusion model from Meituan LongCat. The diffusers integration exposes a standard [`DiffusionPipeline`] interface for text-conditioned audio generation.
|
||||
|
||||
This pipeline supports loading the original flat LongCat checkpoint layout from either a local directory or a Hugging Face Hub repository containing:
|
||||
|
||||
- `config.json`
|
||||
- `model.safetensors`
|
||||
|
||||
The loader builds the text encoder, transformer, and VAE from `config.json`, restores component weights from `model.safetensors`, and ties the shared UMT5 embedding when needed.
|
||||
|
||||
This pipeline was adapted from the LongCat-AudioDiT reference implementation: https://github.com/meituan-longcat/LongCat-AudioDiT
|
||||
|
||||
## Usage
|
||||
|
||||
```py
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from diffusers import LongCatAudioDiTPipeline
|
||||
|
||||
pipeline = LongCatAudioDiTPipeline.from_pretrained(
|
||||
"meituan-longcat/LongCat-AudioDiT-1B",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipeline = pipeline.to("cuda")
|
||||
|
||||
audio = pipeline(
|
||||
prompt="A calm ocean wave ambience with soft wind in the background.",
|
||||
audio_end_in_s=5.0,
|
||||
num_inference_steps=16,
|
||||
guidance_scale=4.0,
|
||||
output_type="pt",
|
||||
).audios
|
||||
|
||||
output = audio[0, 0].float().cpu().numpy()
|
||||
sf.write("longcat.wav", output, pipeline.sample_rate)
|
||||
```
|
||||
|
||||
## Tips
|
||||
|
||||
- `audio_end_in_s` is the most direct way to control output duration.
|
||||
- `output_type="pt"` returns a PyTorch tensor shaped `(batch, channels, samples)`.
|
||||
|
||||
## LongCatAudioDiTPipeline
|
||||
|
||||
[[autodoc]] LongCatAudioDiTPipeline
|
||||
- all
|
||||
- __call__
|
||||
- from_pretrained
|
||||
@@ -29,6 +29,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
|---|---|
|
||||
| [AnimateDiff](animatediff) | text2video |
|
||||
| [AudioLDM2](audioldm2) | text2audio |
|
||||
| [LongCat-AudioDiT](longcat_audio_dit) | text2audio |
|
||||
| [AuraFlow](aura_flow) | text2image |
|
||||
| [Bria 3.2](bria_3_2) | text2image |
|
||||
| [CogVideoX](cogvideox) | text2video |
|
||||
|
||||
@@ -895,9 +895,8 @@ class TokenEmbeddingsHandler:
|
||||
self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)
|
||||
|
||||
# random initialization of new tokens
|
||||
embeds = (
|
||||
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
|
||||
)
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
|
||||
std_token_embedding = embeds.weight.data.std()
|
||||
|
||||
logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
|
||||
@@ -905,9 +904,7 @@ class TokenEmbeddingsHandler:
|
||||
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
|
||||
# if initializer_concept are not provided, token embeddings are initialized randomly
|
||||
if args.initializer_concept is None:
|
||||
hidden_size = (
|
||||
text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
|
||||
)
|
||||
hidden_size = text_module.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
|
||||
embeds.weight.data[train_ids] = (
|
||||
torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
@@ -940,7 +937,8 @@ class TokenEmbeddingsHandler:
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
|
||||
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
|
||||
new_token_embeddings = embeds.weight.data[train_ids]
|
||||
|
||||
@@ -962,7 +960,8 @@ class TokenEmbeddingsHandler:
|
||||
@torch.no_grad()
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
|
||||
embeds.weight.data[index_no_updates] = (
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
|
||||
@@ -2112,7 +2111,8 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
|
||||
text_encoder_one.train()
|
||||
if args.enable_t5_ti:
|
||||
|
||||
@@ -763,19 +763,28 @@ class TokenEmbeddingsHandler:
|
||||
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
|
||||
|
||||
# random initialization of new tokens
|
||||
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
|
||||
std_token_embedding = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.std()
|
||||
|
||||
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
|
||||
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||
torch.randn(
|
||||
len(self.train_ids),
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).config.hidden_size,
|
||||
)
|
||||
.to(device=self.device)
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
@@ -794,10 +803,14 @@ class TokenEmbeddingsHandler:
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
self.tokenizers[0]
|
||||
), "Tokenizers should be the same."
|
||||
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
|
||||
assert (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
|
||||
"Tokenizers should be the same."
|
||||
)
|
||||
new_token_embeddings = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[self.train_ids]
|
||||
|
||||
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
|
||||
# text_encoder 1) to keep compatible with the ecosystem.
|
||||
@@ -819,7 +832,9 @@ class TokenEmbeddingsHandler:
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_no_updates] = (
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
|
||||
.to(device=text_encoder.device)
|
||||
.to(dtype=text_encoder.dtype)
|
||||
@@ -830,11 +845,15 @@ class TokenEmbeddingsHandler:
|
||||
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
|
||||
|
||||
index_updates = ~index_no_updates
|
||||
new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
|
||||
new_embeddings = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_updates]
|
||||
off_ratio = std_token_embedding / new_embeddings.std()
|
||||
|
||||
new_embeddings = new_embeddings * (off_ratio**0.1)
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
@@ -1704,7 +1723,8 @@ def main(args):
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.text_model.embeddings.requires_grad_(True)
|
||||
_te_one = text_encoder_one
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
||||
@@ -929,19 +929,28 @@ class TokenEmbeddingsHandler:
|
||||
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
|
||||
|
||||
# random initialization of new tokens
|
||||
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
|
||||
std_token_embedding = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.std()
|
||||
|
||||
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
|
||||
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[self.train_ids] = (
|
||||
torch.randn(
|
||||
len(self.train_ids),
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).config.hidden_size,
|
||||
)
|
||||
.to(device=self.device)
|
||||
.to(dtype=self.dtype)
|
||||
* std_token_embedding
|
||||
)
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"] = (
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
|
||||
)
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.clone()
|
||||
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
|
||||
|
||||
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
@@ -959,10 +968,14 @@ class TokenEmbeddingsHandler:
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
self.tokenizers[0]
|
||||
), "Tokenizers should be the same."
|
||||
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
|
||||
assert (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
|
||||
"Tokenizers should be the same."
|
||||
)
|
||||
new_token_embeddings = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[self.train_ids]
|
||||
|
||||
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
|
||||
# text_encoder 1) to keep compatible with the ecosystem.
|
||||
@@ -984,7 +997,9 @@ class TokenEmbeddingsHandler:
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_no_updates] = (
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
|
||||
.to(device=text_encoder.device)
|
||||
.to(dtype=text_encoder.dtype)
|
||||
@@ -995,11 +1010,15 @@ class TokenEmbeddingsHandler:
|
||||
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
|
||||
|
||||
index_updates = ~index_no_updates
|
||||
new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
|
||||
new_embeddings = (
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_updates]
|
||||
off_ratio = std_token_embedding / new_embeddings.std()
|
||||
|
||||
new_embeddings = new_embeddings * (off_ratio**0.1)
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
(
|
||||
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
@@ -2083,8 +2102,10 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if args.train_text_encoder:
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = accelerator.unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
_te_two = accelerator.unwrap_model(text_encoder_two)
|
||||
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if pivoted:
|
||||
|
||||
@@ -874,10 +874,11 @@ def main(args):
|
||||
token_embeds[x] = token_embeds[y]
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
text_module.encoder.parameters(),
|
||||
text_module.final_layer_norm.parameters(),
|
||||
text_module.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
########################################################
|
||||
|
||||
@@ -1691,7 +1691,8 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
|
||||
@@ -1740,9 +1740,12 @@ def main(args):
|
||||
prompt_embeds = prompt_embeds_cache[step]
|
||||
text_ids = text_ids_cache[step]
|
||||
else:
|
||||
num_repeat_elements = len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
|
||||
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
|
||||
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
|
||||
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
|
||||
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
|
||||
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
|
||||
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
|
||||
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
@@ -1809,10 +1812,11 @@ def main(args):
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = torch.mean(
|
||||
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
target_prior.shape[0], -1
|
||||
),
|
||||
1,
|
||||
|
||||
@@ -1680,9 +1680,12 @@ def main(args):
|
||||
prompt_embeds = prompt_embeds_cache[step]
|
||||
text_ids = text_ids_cache[step]
|
||||
else:
|
||||
num_repeat_elements = len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
|
||||
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
|
||||
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
|
||||
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
|
||||
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
|
||||
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
|
||||
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
|
||||
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
@@ -1752,10 +1755,11 @@ def main(args):
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = torch.mean(
|
||||
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
target_prior.shape[0], -1
|
||||
),
|
||||
1,
|
||||
|
||||
@@ -1896,7 +1896,8 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
|
||||
@@ -1719,8 +1719,10 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = accelerator.unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
_te_two = accelerator.unwrap_model(text_encoder_two)
|
||||
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
|
||||
@@ -1661,8 +1661,10 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
_te_one = accelerator.unwrap_model(text_encoder_one)
|
||||
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
|
||||
_te_two = accelerator.unwrap_model(text_encoder_two)
|
||||
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
|
||||
@@ -45,7 +45,16 @@ def annotate_pipeline(pipe):
|
||||
method = getattr(component, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
setattr(component, method_name, annotate(method, label))
|
||||
|
||||
# Apply fix ONLY for LTX2 pipelines
|
||||
if "LTX2" in pipe.__class__.__name__:
|
||||
func = getattr(method, "__func__", method)
|
||||
wrapped = annotate(func, label)
|
||||
bound_method = wrapped.__get__(component, type(component))
|
||||
setattr(component, method_name, bound_method)
|
||||
else:
|
||||
# keep original behavior for other pipelines
|
||||
setattr(component, method_name, annotate(method, label))
|
||||
|
||||
# Annotate pipeline-level methods
|
||||
if hasattr(pipe, "encode_prompt"):
|
||||
|
||||
@@ -51,7 +51,42 @@ python flux_inference.py
|
||||
|
||||
The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest.
|
||||
|
||||
On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
|
||||
On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel).
|
||||
|
||||
> **Note:** `flux_inference.py` uses `xmp.spawn` (one process per chip) and requires the full model to fit on a single chip. If you run into OOM errors (e.g., on v5e with 16GB HBM per chip), use the SPMD version instead — see below.
|
||||
|
||||
### SPMD version (for v5e-8 and similar)
|
||||
|
||||
On TPU configurations where a single chip cannot hold the full FLUX transformer (~16GB in bf16), use `flux_inference_spmd.py`. This script uses PyTorch/XLA SPMD to shard the transformer across multiple chips using a `(data, model)` mesh — 4-way model parallel so each chip holds ~4GB of weights, with the remaining chips for data parallelism.
|
||||
|
||||
```bash
|
||||
python flux_inference_spmd.py --schnell
|
||||
```
|
||||
|
||||
Key differences from `flux_inference.py`:
|
||||
- **Single-process SPMD** instead of multi-process `xmp.spawn` — the XLA compiler handles all collective communication transparently.
|
||||
- **Transformer weights are sharded** across the `"model"` mesh axis using `xs.mark_sharding`.
|
||||
- **VAE lives on CPU**, moved to XLA only for decode (then moved back), since the transformer stays on device throughout.
|
||||
- **Text encoding** runs on CPU before loading the transformer.
|
||||
|
||||
On a v5litepod-8 (v5e, 8 chips, 16GB HBM each) with FLUX.1-schnell, expect ~1.76 sec/image at steady state (after compilation):
|
||||
|
||||
```
|
||||
2026-04-15 02:24:30 [info ] SPMD mesh: (2, 4), axes: ('data', 'model'), devices: 8
|
||||
2026-04-15 02:24:30 [info ] encoding prompt on CPU...
|
||||
2026-04-15 02:26:20 [info ] loading VAE on CPU...
|
||||
2026-04-15 02:26:20 [info ] loading flux transformer from black-forest-labs/FLUX.1-schnell
|
||||
2026-04-15 02:27:22 [info ] starting compilation run...
|
||||
2026-04-15 02:52:55 [info ] compilation took 1533.4575625509997 sec.
|
||||
2026-04-15 02:52:56 [info ] starting inference run...
|
||||
2026-04-15 02:56:11 [info ] inference time: 195.74092420299985
|
||||
2026-04-15 02:56:13 [info ] inference time: 1.7625778899996476
|
||||
2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec.
|
||||
```
|
||||
|
||||
The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s).
|
||||
|
||||
### v6e-4 results (original `flux_inference.py`)
|
||||
|
||||
```bash
|
||||
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
"""FLUX inference on TPU using PyTorch/XLA SPMD.
|
||||
|
||||
Uses SPMD to shard the transformer across multiple TPU chips, enabling
|
||||
inference on devices where the model doesn't fit on a single chip (e.g., v5e).
|
||||
The VAE is loaded on CPU at startup, moved to XLA for decode, then moved back.
|
||||
"""
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla.experimental.custom_kernel import FlashAttention
|
||||
|
||||
from diffusers import AutoencoderKL, FluxPipeline
|
||||
|
||||
|
||||
cache_path = Path("/tmp/data/compiler_cache_eXp")
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
xr.initialize_cache(str(cache_path), readonly=False)
|
||||
xr.use_spmd()
|
||||
|
||||
logger = structlog.get_logger()
|
||||
metrics_filepath = "/tmp/metrics_report.txt"
|
||||
VAE_SCALE_FACTOR = 8
|
||||
|
||||
|
||||
def _vae_decode(latents, vae, height, width, device):
|
||||
"""Move VAE to XLA, decode latents, move VAE back to CPU."""
|
||||
vae.to(device)
|
||||
latents = FluxPipeline._unpack_latents(latents, height, width, VAE_SCALE_FACTOR)
|
||||
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latents, return_dict=False)[0]
|
||||
vae.to("cpu")
|
||||
return image
|
||||
|
||||
|
||||
def main(args):
    """Run sharded FLUX text-to-image inference on TPU via PyTorch/XLA SPMD.

    Builds a (data, model) device mesh, shards the transformer weights over
    the "model" axis, runs one compilation pass, then times `args.itters`
    steady-state iterations and writes an XLA metrics report.
    """
    # --- SPMD mesh: 4-way model parallel to fit transformer + VAE on v5e chips ---
    num_devices = xr.global_runtime_device_count()
    if num_devices >= 4:
        mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model"))
    else:
        # Fix: the original evaluated `NotImplementedError` without raising it,
        # then crashed later with an undefined `mesh`.
        raise NotImplementedError(f"This script requires at least 4 TPU devices, found {num_devices}")
    xs.set_global_mesh(mesh)
    logger.info(f"SPMD mesh: {mesh.mesh_shape}, axes: {mesh.axis_names}, devices: {num_devices}")

    # --- Profiler ---
    profile_path = Path("/tmp/data/profiler_out_eXp")
    profile_path.mkdir(parents=True, exist_ok=True)
    profiler_port = 9012
    profile_duration = args.profile_duration
    if args.profile:
        logger.info(f"starting profiler on port {profiler_port}")
        _ = xp.start_server(profiler_port)

    device = xm.xla_device()

    # --- Checkpoint ---
    if args.schnell:
        ckpt_id = "black-forest-labs/FLUX.1-schnell"
    else:
        ckpt_id = "black-forest-labs/FLUX.1-dev"

    # --- Text encoding (CPU) ---
    # Text encoders run once on CPU; only their embeddings move to the device.
    prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
    logger.info("encoding prompt on CPU...")
    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, _ = text_pipe.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )
    image_processor = text_pipe.image_processor
    # Release the text encoders before loading the transformer.
    del text_pipe

    # --- Load VAE on CPU (moved to XLA only for decode) ---
    logger.info("loading VAE on CPU...")
    vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16)

    # --- Load transformer and shard ---
    logger.info(f"loading flux transformer from {ckpt_id}")
    flux_pipe = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder=None,
        tokenizer=None,
        text_encoder_2=None,
        tokenizer_2=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    ).to(device)

    # Shard every >=2D weight along its largest dimension across the "model" axis.
    for _, param in flux_pipe.transformer.named_parameters():
        if param.dim() >= 2:
            spec = [None] * param.dim()
            largest_dim = max(range(param.dim()), key=lambda d: param.shape[d])
            spec[largest_dim] = "model"
            xs.mark_sharding(param, mesh, tuple(spec))

    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
    # Override the Pallas flash-attention tiling; 1536 was tuned for these
    # sequence lengths (see the benchmark notes accompanying this script).
    FlashAttention.DEFAULT_BLOCK_SIZES = {
        "block_q": 1536,
        "block_k_major": 1536,
        "block_k": 1536,
        "block_b": 1536,
        "block_q_major_dkv": 1536,
        "block_k_major_dkv": 1536,
        "block_q_dkv": 1536,
        "block_k_dkv": 1536,
        "block_q_dq": 1536,
        "block_k_dq": 1536,
        "block_k_major_dq": 1536,
    }

    width = args.width
    height = args.height
    guidance = args.guidance
    n_steps = 4 if args.schnell else 28

    prompt_embeds = prompt_embeds.to(device)
    pooled_prompt_embeds = pooled_prompt_embeds.to(device)
    xs.mark_sharding(prompt_embeds, mesh, ("data", None, None))
    xs.mark_sharding(pooled_prompt_embeds, mesh, ("data", None))

    # --- Compilation run ---
    # Fix: use n_steps (not a hard-coded 28) so the warm-up compiles the same
    # graph the timed loop uses; with --schnell the old code compiled a 28-step
    # graph and then recompiled inside the first timed 4-step iteration.
    logger.info("starting compilation run...")
    ts = perf_counter()
    latents = flux_pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=n_steps,
        guidance_scale=guidance,
        height=height,
        width=width,
        output_type="latent",
    ).images
    image = _vae_decode(latents, vae, height, width, device)
    image = image_processor.postprocess(image)[0]
    logger.info(f"compilation took {perf_counter() - ts} sec.")
    image.save("/tmp/compile_out.png")

    # --- Inference loop ---
    seed = 4096 if args.seed is None else args.seed
    xm.set_rng_state(seed=seed, device=device)
    times = []
    logger.info("starting inference run...")
    for _ in range(args.itters):
        ts = perf_counter()

        if args.profile:
            xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
        latents = flux_pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=n_steps,
            guidance_scale=guidance,
            height=height,
            width=width,
            output_type="latent",
        ).images
        image = _vae_decode(latents, vae, height, width, device)
        image = image_processor.postprocess(image)[0]
        inference_time = perf_counter() - ts
        logger.info(f"inference time: {inference_time}")
        times.append(inference_time)

    logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
    image.save("/tmp/inference_out.png")
    metrics_report = met.metrics_report()
    with open(metrics_filepath, "w+") as fout:
        fout.write(metrics_report)
    logger.info(f"saved metric information as {metrics_filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: declare the benchmark flags and launch the SPMD run.
    parser = ArgumentParser()
    parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
    parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
    parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
    parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
    parser.add_argument("--seed", type=int, default=None, help="seed for inference")
    parser.add_argument("--profile", action="store_true", help="enable profiling")
    parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
    parser.add_argument("--itters", type=int, default=15, help="items to run inference and get avg time in sec.")
    main(parser.parse_args())
|
||||
@@ -702,9 +702,10 @@ def main():
|
||||
vae.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.text_model.encoder.requires_grad_(False)
|
||||
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
|
||||
text_module.encoder.requires_grad_(False)
|
||||
text_module.final_layer_norm.requires_grad_(False)
|
||||
text_module.embeddings.position_embedding.requires_grad_(False)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
# Keep unet in train mode if we are using gradient checkpointing to save memory.
|
||||
|
||||
@@ -717,12 +717,14 @@ def main():
|
||||
unet.requires_grad_(False)
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder_1.text_model.encoder.requires_grad_(False)
|
||||
text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
text_encoder_2.text_model.encoder.requires_grad_(False)
|
||||
text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
text_module_1 = text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
|
||||
text_module_1.encoder.requires_grad_(False)
|
||||
text_module_1.final_layer_norm.requires_grad_(False)
|
||||
text_module_1.embeddings.position_embedding.requires_grad_(False)
|
||||
text_module_2 = text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
|
||||
text_module_2.encoder.requires_grad_(False)
|
||||
text_module_2.final_layer_norm.requires_grad_(False)
|
||||
text_module_2.embeddings.position_embedding.requires_grad_(False)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder_1.gradient_checkpointing_enable()
|
||||
@@ -767,8 +769,12 @@ def main():
|
||||
optimizer = optimizer_class(
|
||||
# only optimize the embeddings
|
||||
[
|
||||
text_encoder_1.text_model.embeddings.token_embedding.weight,
|
||||
text_encoder_2.text_model.embeddings.token_embedding.weight,
|
||||
(
|
||||
text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
|
||||
).embeddings.token_embedding.weight,
|
||||
(
|
||||
text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
|
||||
).embeddings.token_embedding.weight,
|
||||
],
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
|
||||
224
scripts/convert_longcat_audio_dit_to_diffusers.py
Normal file
224
scripts/convert_longcat_audio_dit_to_diffusers.py
Normal file
@@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Usage:
|
||||
# python scripts/convert_longcat_audio_dit_to_diffusers.py --checkpoint_path /path/to/model --output_path /data/models
|
||||
# python scripts/convert_longcat_audio_dit_to_diffusers.py --repo_id meituan-longcat/LongCat-AudioDiT-1B --output_path /data/models
|
||||
# python scripts/convert_longcat_audio_dit_to_diffusers.py --checkpoint_path /path/to/model --output_path /data/models --dtype fp16
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoTokenizer, UMT5Config, UMT5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LongCatAudioDiTPipeline,
|
||||
LongCatAudioDiTTransformer,
|
||||
LongCatAudioDiTVae,
|
||||
)
|
||||
|
||||
|
||||
def find_checkpoint(input_dir: Path):
    """Locate a safetensors checkpoint under `input_dir`.

    Probes for a single-file `model.safetensors` or a sharded
    `model.safetensors.index.json`, first in `input_dir` itself and then one
    level down in each subdirectory. For sharded checkpoints the first shard
    listed in the index's `weight_map` is returned.

    Args:
        input_dir: Directory to search.

    Returns:
        Tuple `(model_dir, checkpoint_file)`.

    Raises:
        FileNotFoundError: If no checkpoint is found anywhere.
    """

    def _probe(directory: Path):
        # Single-file checkpoint takes precedence over a shard index.
        single_file = directory / "model.safetensors"
        if single_file.exists():
            return directory, single_file

        index_file = directory / "model.safetensors.index.json"
        if index_file.exists():
            with open(index_file) as f:
                index = json.load(f)
            weight_map = index.get("weight_map", {})
            first_weight = list(weight_map.values())[0]
            return directory, directory / first_weight

        return None

    # The original duplicated the probe logic verbatim for the top-level
    # directory and each subdirectory; a single helper removes the drift risk.
    candidates = [input_dir] + [child for child in input_dir.iterdir() if child.is_dir()]
    for directory in candidates:
        found = _probe(directory)
        if found is not None:
            return found

    raise FileNotFoundError(f"No checkpoint found in {input_dir}")
|
||||
|
||||
|
||||
def convert_longcat_audio_dit(
    checkpoint_path: str | None = None,
    repo_id: str | None = None,
    output_path: str = "",
    dtype: str = "fp32",
    text_encoder_model: str = "google/umt5-xxl",
):
    """Convert a LongCat-AudioDiT checkpoint to the diffusers pipeline layout.

    Exactly one of ``checkpoint_path`` (local directory) or ``repo_id`` (Hub
    repo to download) must be provided. The combined state dict is split by
    component prefix into transformer / VAE / text-encoder weights, loaded
    into the corresponding diffusers modules, and the assembled pipeline is
    saved under ``<output_path>/<model_name>-Diffusers``.

    Raises:
        ValueError: If both or neither of checkpoint_path / repo_id are given.
        FileNotFoundError: If the checkpoint directory or config.json is missing.
        RuntimeError: If the text-encoder state dict has unexpected keys.
    """
    if not checkpoint_path and not repo_id:
        raise ValueError("Either --checkpoint_path or --repo_id must be provided")
    if checkpoint_path and repo_id:
        raise ValueError("Cannot specify both --checkpoint_path and --repo_id")

    # Map CLI dtype names to torch dtypes; argparse restricts the choices, so
    # the fp32 fallback only covers direct (non-CLI) callers.
    dtype_map = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    torch_dtype = dtype_map.get(dtype, torch.float32)

    if repo_id:
        input_dir = Path(snapshot_download(repo_id, local_files_only=False))
        model_name = repo_id.split("/")[-1]
    else:
        input_dir = Path(checkpoint_path)
        if not input_dir.exists():
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
        model_name = None  # derived from the resolved checkpoint directory below

    # NOTE: `checkpoint_path` is rebound here from the CLI string to the
    # resolved safetensors file path.
    model_dir, checkpoint_path = find_checkpoint(input_dir)
    if model_name is None:
        model_name = model_dir.name

    config_path = model_dir / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"config.json not found in {model_dir}")

    with open(config_path) as f:
        config = json.load(f)

    state_dict = load_file(checkpoint_path)

    # Split the combined state dict by component prefix, stripping the prefix
    # (len("transformer.") == 12, len("vae.") == 4, len("text_encoder.") == 13).
    transformer_keys = [k for k in state_dict.keys() if k.startswith("transformer.")]
    transformer_state_dict = {key[12:]: state_dict[key] for key in transformer_keys}

    vae_keys = [k for k in state_dict.keys() if k.startswith("vae.")]
    vae_state_dict = {key[4:]: state_dict[key] for key in vae_keys}

    text_encoder_keys = [k for k in state_dict.keys() if k.startswith("text_encoder.")]
    text_encoder_state_dict = {key[13:]: state_dict[key] for key in text_encoder_keys}

    # Transformer hyperparameters come from config.json; the .get() defaults
    # presumably mirror the reference implementation — confirm against upstream.
    transformer = LongCatAudioDiTTransformer(
        dit_dim=config["dit_dim"],
        dit_depth=config["dit_depth"],
        dit_heads=config["dit_heads"],
        dit_text_dim=config["dit_text_dim"],
        latent_dim=config["latent_dim"],
        dropout=config.get("dit_dropout", 0.0),
        bias=config.get("dit_bias", True),
        cross_attn=config.get("dit_cross_attn", True),
        adaln_type=config.get("dit_adaln_type", "global"),
        adaln_use_text_cond=config.get("dit_adaln_use_text_cond", True),
        long_skip=config.get("dit_long_skip", True),
        text_conv=config.get("dit_text_conv", True),
        qk_norm=config.get("dit_qk_norm", True),
        cross_attn_norm=config.get("dit_cross_attn_norm", False),
        eps=config.get("dit_eps", 1e-6),
        use_latent_condition=config.get("dit_use_latent_condition", True),
    )
    transformer.load_state_dict(transformer_state_dict, strict=True)
    transformer = transformer.to(dtype=torch_dtype)

    # "model_type" is metadata, not a LongCatAudioDiTVae constructor argument.
    vae_config = dict(config["vae_config"])
    vae_config.pop("model_type", None)
    vae = LongCatAudioDiTVae(**vae_config)
    vae.load_state_dict(vae_state_dict, strict=True)
    vae = vae.to(dtype=torch_dtype)

    text_encoder_config = UMT5Config.from_dict(config["text_encoder_config"])
    text_encoder = UMT5EncoderModel(text_encoder_config)
    # strict=False because the checkpoint may omit the tied input embedding.
    text_missing, text_unexpected = text_encoder.load_state_dict(text_encoder_state_dict, strict=False)

    # Only "shared.weight" may legitimately be missing; anything else missing
    # or extra indicates a conversion mismatch and aborts the run.
    allowed_missing = {"shared.weight"}
    unexpected_missing = set(text_missing) - allowed_missing
    if unexpected_missing:
        raise RuntimeError(f"Unexpected missing text encoder weights: {sorted(unexpected_missing)}")
    if text_unexpected:
        raise RuntimeError(f"Unexpected text encoder weights: {sorted(text_unexpected)}")
    if "shared.weight" in text_missing:
        # Restore the embedding tie from the encoder's token embedding.
        text_encoder.shared.weight.data.copy_(text_encoder.encoder.embed_tokens.weight.data)

    text_encoder = text_encoder.to(dtype=torch_dtype)

    tokenizer = AutoTokenizer.from_pretrained(text_encoder_model)

    # Scheduler defaults; any values in config["scheduler_config"] override them.
    scheduler_config = {"shift": 1.0, "invert_sigmas": True}
    scheduler_config.update(config.get("scheduler_config", {}))
    scheduler = FlowMatchEulerDiscreteScheduler(**scheduler_config)

    pipeline = LongCatAudioDiTPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Audio-specific metadata attached to the pipeline instance so it is
    # serialized alongside the components.
    pipeline.sample_rate = config.get("sampling_rate", 24000)
    pipeline.vae_scale_factor = config.get("vae_scale_factor", config.get("latent_hop", 2048))
    pipeline.max_wav_duration = config.get("max_wav_duration", 30.0)
    pipeline.text_norm_feat = config.get("text_norm_feat", True)
    pipeline.text_add_embed = config.get("text_add_embed", True)

    output_path = Path(output_path) / f"{model_name}-Diffusers"
    output_path.mkdir(parents=True, exist_ok=True)

    pipeline.save_pretrained(output_path)
|
||||
|
||||
|
||||
def get_args():
    """Build and parse the command-line arguments for the conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local model directory")
    parser.add_argument("--repo_id", type=str, default=None, help="HuggingFace repo_id to download model")
    parser.add_argument("--output_path", type=str, required=True, help="Output directory")
    parser.add_argument(
        "--dtype", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Data type for converted weights"
    )
    parser.add_argument(
        "--text_encoder_model",
        type=str,
        default="google/umt5-xxl",
        help="HuggingFace model ID for text encoder tokenizer",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Namespace attribute names match the converter's keyword parameters
    # exactly, so the parsed flags can be forwarded wholesale.
    convert_longcat_audio_dit(**vars(get_args()))
|
||||
@@ -254,6 +254,8 @@ else:
|
||||
"Kandinsky3UNet",
|
||||
"Kandinsky5Transformer3DModel",
|
||||
"LatteTransformer3DModel",
|
||||
"LongCatAudioDiTTransformer",
|
||||
"LongCatAudioDiTVae",
|
||||
"LongCatImageTransformer2DModel",
|
||||
"LTX2VideoTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
@@ -599,6 +601,7 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LLaDA2Pipeline",
|
||||
"LLaDA2PipelineOutput",
|
||||
"LongCatAudioDiTPipeline",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ConditionPipeline",
|
||||
@@ -1058,6 +1061,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky3UNet,
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatAudioDiTTransformer,
|
||||
LongCatAudioDiTVae,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
@@ -1378,6 +1383,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LLaDA2Pipeline,
|
||||
LLaDA2PipelineOutput,
|
||||
LongCatAudioDiTPipeline,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ConditionPipeline,
|
||||
|
||||
@@ -50,6 +50,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_longcat_audio_dit"] = ["LongCatAudioDiTVae"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
_import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
@@ -112,6 +113,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"]
|
||||
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
|
||||
@@ -180,6 +182,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderTiny,
|
||||
AutoencoderVidTok,
|
||||
ConsistencyDecoderVAE,
|
||||
LongCatAudioDiTVae,
|
||||
VQModel,
|
||||
)
|
||||
from .cache_utils import CacheMixin
|
||||
@@ -233,6 +236,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanVideoTransformer3DModel,
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatAudioDiTTransformer,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
|
||||
@@ -19,6 +19,7 @@ from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_longcat_audio_dit import LongCatAudioDiTVae
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
from .autoencoder_rae import AutoencoderRAE
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Adapted from the LongCat-AudioDiT reference implementation:
|
||||
# https://github.com/meituan-longcat/LongCat-AudioDiT
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin
|
||||
|
||||
|
||||
def _wn_conv1d(in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, bias=True):
|
||||
return weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias))
|
||||
|
||||
|
||||
def _wn_conv_transpose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
    """Snake activation x + sin^2(alpha*x)/beta with per-channel learnable alpha/beta.

    With ``alpha_logscale=True`` the parameters are stored in log space, so
    the zero initialization corresponds to alpha = beta = 1.
    """

    def __init__(self, channels: int, alpha_logscale: bool = True):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(torch.zeros(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Broadcast the per-channel parameters over the batch and time axes.
        alpha = self.alpha.view(1, -1, 1)
        beta = self.beta.view(1, -1, 1)
        if self.alpha_logscale:
            alpha = alpha.exp()
            beta = beta.exp()
        # Epsilon guards against division by zero for very negative log-beta.
        return hidden_states + (1.0 / (beta + 1e-9)) * torch.sin(hidden_states * alpha).pow(2)
|
||||
|
||||
|
||||
def _get_vae_activation(name: str, channels: int = 0) -> nn.Module:
|
||||
if name == "elu":
|
||||
act = nn.ELU()
|
||||
elif name == "snake":
|
||||
act = Snake1d(channels)
|
||||
else:
|
||||
raise ValueError(f"Unknown activation: {name}")
|
||||
return act
|
||||
|
||||
|
||||
def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
|
||||
batch, channels, width = hidden_states.size()
|
||||
return (
|
||||
hidden_states.view(batch, channels // factor, factor, width)
|
||||
.permute(0, 1, 3, 2)
|
||||
.contiguous()
|
||||
.view(batch, channels // factor, width * factor)
|
||||
)
|
||||
|
||||
|
||||
class DownsampleShortcut(nn.Module):
    """Parameter-free downsampling shortcut.

    Folds `factor` consecutive time steps into the channel dimension, then
    averages fixed-size channel groups down to `out_channels`. Assumes the
    input width is divisible by `factor`.
    """

    def __init__(self, in_channels: int, out_channels: int, factor: int):
        super().__init__()
        self.factor = factor
        # How many folded channels are averaged into each output channel.
        self.group_size = in_channels * factor // out_channels
        self.out_channels = out_channels

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, channels, width = hidden_states.shape
        new_width = width // self.factor
        # Space-to-channel fold (inverse of the 1D pixel shuffle).
        folded = (
            hidden_states.view(batch, channels, new_width, self.factor)
            .permute(0, 1, 3, 2)
            .contiguous()
            .view(batch, channels * self.factor, new_width)
        )
        grouped = folded.view(batch, self.out_channels, self.group_size, new_width)
        return grouped.mean(dim=2)
|
||||
|
||||
|
||||
class UpsampleShortcut(nn.Module):
    """Parameter-free upsampling shortcut: channel duplication then a 1D pixel shuffle."""

    def __init__(self, in_channels: int, out_channels: int, factor: int):
        super().__init__()
        self.factor = factor
        # Duplication count so the shuffle lands exactly on `out_channels`.
        self.repeats = out_channels * factor // in_channels

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = hidden_states.repeat_interleave(self.repeats, dim=1)
        # Inlined 1D pixel shuffle: (B, C, W) -> (B, C // factor, W * factor).
        batch, channels, width = expanded.size()
        shuffled = (
            expanded.view(batch, channels // self.factor, self.factor, width)
            .permute(0, 1, 3, 2)
            .contiguous()
            .view(batch, channels // self.factor, width * self.factor)
        )
        return shuffled
|
||||
|
||||
|
||||
class VaeResidualUnit(nn.Module):
    """Residual unit: act -> dilated conv -> act -> 1x1 conv, plus identity skip.

    NOTE(review): the first activation is sized with ``out_channels`` even
    though its input has ``in_channels``; all visible call sites pass
    in_channels == out_channels — confirm before using them unequal.
    """

    def __init__(
        self, in_channels: int, out_channels: int, dilation: int, kernel_size: int = 7, act_fn: str = "snake"
    ):
        super().__init__()
        # "Same" padding for the dilated convolution.
        padding = (dilation * (kernel_size - 1)) // 2
        stages = [
            _get_vae_activation(act_fn, channels=out_channels),
            _wn_conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
            _get_vae_activation(act_fn, channels=out_channels),
            _wn_conv1d(out_channels, out_channels, kernel_size=1),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states + self.layers(hidden_states)
|
||||
|
||||
|
||||
class VaeEncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units, then a strided downsampling conv.

    With ``downsample_shortcut="averaging"`` a parameter-free averaging
    shortcut is added around the whole stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        act_fn: str = "snake",
        downsample_shortcut: str = "none",
    ):
        super().__init__()
        stages = [VaeResidualUnit(in_channels, in_channels, dilation=d, act_fn=act_fn) for d in (1, 3, 9)]
        stages.append(_get_vae_activation(act_fn, channels=in_channels))
        stages.append(
            _wn_conv1d(in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
        )
        self.layers = nn.Sequential(*stages)
        if downsample_shortcut == "averaging":
            self.residual = DownsampleShortcut(in_channels, out_channels, stride)
        else:
            self.residual = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        downsampled = self.layers(hidden_states)
        if self.residual is not None:
            downsampled = downsampled + self.residual(hidden_states)
        return downsampled
|
||||
|
||||
|
||||
class VaeDecoderBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, then three dilated residual units.

    With ``upsample_shortcut="duplicating"`` a parameter-free duplicating
    shortcut is added around the whole stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        act_fn: str = "snake",
        upsample_shortcut: str = "none",
    ):
        super().__init__()
        stages = [
            _get_vae_activation(act_fn, channels=in_channels),
            _wn_conv_transpose1d(
                in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
            ),
        ]
        stages.extend(VaeResidualUnit(out_channels, out_channels, dilation=d, act_fn=act_fn) for d in (1, 3, 9))
        self.layers = nn.Sequential(*stages)
        if upsample_shortcut == "duplicating":
            self.residual = UpsampleShortcut(in_channels, out_channels, stride)
        else:
            self.residual = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        upsampled = self.layers(hidden_states)
        if self.residual is not None:
            upsampled = upsampled + self.residual(hidden_states)
        return upsampled
|
||||
|
||||
|
||||
class AudioDiTVaeEncoder(nn.Module):
    """Convolutional encoder of the LongCat audio VAE.

    Stacks a stem conv, one strided `VaeEncoderBlock` per channel-multiplier
    step, and an output conv to `encoder_latent_dim` channels, with an
    optional parameter-free averaging shortcut around the final conv.

    NOTE(review): `latent_dim` is accepted but never used here — presumably
    kept for signature symmetry with the decoder; confirm before removing.
    """

    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 128,
        c_mults: list[int] | None = None,
        strides: list[int] | None = None,
        latent_dim: int = 64,
        encoder_latent_dim: int = 128,
        act_fn: str = "snake",
        downsample_shortcut: str = "averaging",
        out_shortcut: str = "averaging",
    ):
        super().__init__()
        # Prepend a 1x multiplier for the stem; defaults give 5 downsampling stages.
        c_mults = [1] + (c_mults or [1, 2, 4, 8, 16])
        strides = list(strides or [2] * (len(c_mults) - 1))
        # Pad a short stride list by repeating its last entry; truncate a long one.
        if len(strides) < len(c_mults) - 1:
            strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides)))
        else:
            strides = strides[: len(c_mults) - 1]
        channels_base = channels
        # Stem conv lifts the raw waveform channels to the base width.
        layers = [_wn_conv1d(in_channels, c_mults[0] * channels_base, kernel_size=7, padding=3)]
        for idx in range(len(c_mults) - 1):
            layers.append(
                VaeEncoderBlock(
                    c_mults[idx] * channels_base,
                    c_mults[idx + 1] * channels_base,
                    strides[idx],
                    act_fn=act_fn,
                    downsample_shortcut=downsample_shortcut,
                )
            )
        # Final projection to the encoder latent width.
        layers.append(_wn_conv1d(c_mults[-1] * channels_base, encoder_latent_dim, kernel_size=3, padding=1))
        self.layers = nn.Sequential(*layers)
        # Optional averaging shortcut around the final projection (factor 1: channel averaging only).
        self.shortcut = (
            DownsampleShortcut(c_mults[-1] * channels_base, encoder_latent_dim, 1)
            if out_shortcut == "averaging"
            else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Run everything except the final projection, so the shortcut can tap
        # the pre-projection features.
        hidden_states = self.layers[:-1](hidden_states)
        output_hidden_states = self.layers[-1](hidden_states)
        if self.shortcut is not None:
            shortcut = self.shortcut(hidden_states)
            output_hidden_states = output_hidden_states + shortcut
        return output_hidden_states
|
||||
|
||||
|
||||
class AudioDiTVaeDecoder(nn.Module):
    """Convolutional decoder of the LongCat audio VAE.

    Mirrors the encoder: a stem conv from `latent_dim`, one transposed-conv
    `VaeDecoderBlock` per channel-multiplier step (walked in reverse), a final
    conv back to `in_channels`, and an optional duplicating shortcut around
    the stem.
    """

    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 128,
        c_mults: list[int] | None = None,
        strides: list[int] | None = None,
        latent_dim: int = 64,
        act_fn: str = "snake",
        in_shortcut: str = "duplicating",
        final_tanh: bool = False,
        upsample_shortcut: str = "duplicating",
    ):
        super().__init__()
        # Same multiplier/stride normalization as the encoder (pad or truncate
        # the stride list to len(c_mults) - 1 entries).
        c_mults = [1] + (c_mults or [1, 2, 4, 8, 16])
        strides = list(strides or [2] * (len(c_mults) - 1))
        if len(strides) < len(c_mults) - 1:
            strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides)))
        else:
            strides = strides[: len(c_mults) - 1]
        channels_base = channels

        # Optional duplicating shortcut around the stem conv (factor 1:
        # channel duplication only).
        self.shortcut = (
            UpsampleShortcut(latent_dim, c_mults[-1] * channels_base, 1) if in_shortcut == "duplicating" else None
        )

        # Stem conv lifts latents to the widest channel count, then blocks
        # upsample back down the multiplier ladder in reverse order.
        layers = [_wn_conv1d(latent_dim, c_mults[-1] * channels_base, kernel_size=7, padding=3)]
        for idx in range(len(c_mults) - 1, 0, -1):
            layers.append(
                VaeDecoderBlock(
                    c_mults[idx] * channels_base,
                    c_mults[idx - 1] * channels_base,
                    strides[idx - 1],
                    act_fn=act_fn,
                    upsample_shortcut=upsample_shortcut,
                )
            )
        layers.append(_get_vae_activation(act_fn, channels=c_mults[0] * channels_base))
        # Output conv back to the waveform channel count; optional tanh squashes to [-1, 1].
        layers.append(_wn_conv1d(c_mults[0] * channels_base, in_channels, kernel_size=7, padding=3, bias=False))
        layers.append(nn.Tanh() if final_tanh else nn.Identity())
        self.layers = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.shortcut is None:
            return self.layers(hidden_states)
        # Apply the shortcut around the stem conv only, then run the rest.
        hidden_states = self.shortcut(hidden_states) + self.layers[0](hidden_states)
        return self.layers[1:](hidden_states)
|
||||
|
||||
|
||||
@dataclass
class LongCatAudioDiTVaeEncoderOutput(BaseOutput):
    """Output of `LongCatAudioDiTVae.encode`."""

    # Sampled (or mean) latents, already divided by the configured latent scale.
    latents: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class LongCatAudioDiTVaeDecoderOutput(BaseOutput):
    """Output of `LongCatAudioDiTVae.decode`."""

    # Reconstructed waveform tensor produced by the decoder.
    sample: torch.Tensor
|
||||
|
||||
|
||||
class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin):
    """1D convolutional VAE pairing `AudioDiTVaeEncoder` with `AudioDiTVaeDecoder`.

    Latents are divided by ``config.scale`` on encode and multiplied back on decode.
    Both `encode` and `decode` return float32 tensors regardless of the submodule dtype.
    """

    _supports_group_offloading = False

    @register_to_config
    def __init__(
        self,
        in_channels: int = 1,
        channels: int = 128,
        c_mults: list[int] | None = None,
        strides: list[int] | None = None,
        latent_dim: int = 64,
        encoder_latent_dim: int = 128,
        act_fn: str | None = None,
        use_snake: bool | None = None,
        downsample_shortcut: str = "averaging",
        upsample_shortcut: str = "duplicating",
        out_shortcut: str = "averaging",
        in_shortcut: str = "duplicating",
        final_tanh: bool = False,
        downsampling_ratio: int = 2048,
        sample_rate: int = 24000,
        scale: float = 0.71,
    ):
        super().__init__()
        if act_fn is None:
            # Back-compat: honor the legacy `use_snake` flag only when no explicit act_fn is given.
            # use_snake in (None, True) -> "snake"; use_snake False -> "elu".
            act_fn = "elu" if use_snake is False else "snake"
        self.encoder = AudioDiTVaeEncoder(
            in_channels=in_channels,
            channels=channels,
            c_mults=c_mults,
            strides=strides,
            latent_dim=latent_dim,
            encoder_latent_dim=encoder_latent_dim,
            act_fn=act_fn,
            downsample_shortcut=downsample_shortcut,
            out_shortcut=out_shortcut,
        )
        self.decoder = AudioDiTVaeDecoder(
            in_channels=in_channels,
            channels=channels,
            c_mults=c_mults,
            strides=strides,
            latent_dim=latent_dim,
            act_fn=act_fn,
            in_shortcut=in_shortcut,
            final_tanh=final_tanh,
            upsample_shortcut=upsample_shortcut,
        )

    @apply_forward_hook
    def encode(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = True,
        return_dict: bool = True,
        generator: torch.Generator | None = None,
    ) -> LongCatAudioDiTVaeEncoderOutput | tuple[torch.Tensor]:
        """Encode audio samples to latents, optionally sampling from the posterior."""
        encoder_dtype = next(self.encoder.parameters()).dtype
        if sample.dtype != encoder_dtype:
            sample = sample.to(encoder_dtype)
        mean, scale_param = self.encoder(sample).chunk(2, dim=1)
        # Softplus keeps the std strictly positive; the epsilon avoids a degenerate posterior.
        std = F.softplus(scale_param) + 1e-4
        if sample_posterior:
            noise = randn_tensor(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype)
            latents = mean + std * noise
        else:
            latents = mean
        latents = latents / self.config.scale
        if encoder_dtype != torch.float32:
            latents = latents.float()
        if not return_dict:
            return (latents,)
        return LongCatAudioDiTVaeEncoderOutput(latents=latents)

    @apply_forward_hook
    def decode(
        self, latents: torch.Tensor, return_dict: bool = True
    ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]:
        """Decode (scaled) latents back to waveform samples."""
        decoder_dtype = next(self.decoder.parameters()).dtype
        latents = latents * self.config.scale
        if latents.dtype != decoder_dtype:
            latents = latents.to(decoder_dtype)
        decoded = self.decoder(latents)
        if decoder_dtype != torch.float32:
            decoded = decoded.float()
        if not return_dict:
            return (decoded,)
        return LongCatAudioDiTVaeDecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: torch.Generator | None = None,
    ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]:
        """Full encode/decode round trip (reconstruction)."""
        latents = self.encode(sample, sample_posterior=sample_posterior, return_dict=True, generator=generator).latents
        decoded = self.decode(latents, return_dict=True).sample
        if not return_dict:
            return (decoded,)
        return LongCatAudioDiTVaeDecoderOutput(sample=decoded)
|
||||
@@ -36,6 +36,7 @@ if is_torch_available():
|
||||
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
|
||||
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
|
||||
from .transformer_kandinsky import Kandinsky5Transformer3DModel
|
||||
from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer
|
||||
from .transformer_longcat_image import LongCatImageTransformer2DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_ltx2 import LTX2VideoTransformer3DModel
|
||||
|
||||
@@ -0,0 +1,605 @@
|
||||
# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Adapted from the LongCat-AudioDiT reference implementation:
|
||||
# https://github.com/meituan-longcat/LongCat-AudioDiT
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
@dataclass
class LongCatAudioDiTTransformerOutput(BaseOutput):
    """Output of `LongCatAudioDiTTransformer.forward`."""

    # Predicted latent tensor with the model's `latent_dim` as its last dimension.
    sample: torch.Tensor
||||
|
||||
|
||||
class AudioDiTSinusPositionEmbedding(nn.Module):
    """Sinusoidal embedding for diffusion timesteps.

    Returns a ``(batch, dim)`` tensor whose first half holds sines and second half holds
    cosines of the scaled timesteps at geometrically spaced frequencies.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, timesteps: torch.Tensor, scale: float = 1000.0) -> torch.Tensor:
        half_dim = self.dim // 2
        # Geometric frequency schedule; max(..., 1) guards against half_dim == 1.
        decay = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(-decay * torch.arange(half_dim, device=timesteps.device).float())
        angles = scale * timesteps.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
||||
|
||||
|
||||
class AudioDiTTimestepEmbedding(nn.Module):
    """Embed a diffusion timestep: sinusoidal features followed by a two-layer MLP."""

    def __init__(self, dim: int, freq_embed_dim: int = 256):
        super().__init__()
        self.time_embed = AudioDiTSinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        # Cast back to the timestep dtype so the MLP runs in the caller's precision.
        sinusoid = self.time_embed(timestep).to(timestep.dtype)
        return self.time_mlp(sinusoid)
|
||||
|
||||
|
||||
class AudioDiTRotaryEmbedding(nn.Module):
    """Build and cache rotary (RoPE) cos/sin tables for a given head dimension."""

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 100000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

    @lru_cache_unless_export(maxsize=128)
    def _build(self, seq_len: int, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the full (seq_len, dim) cos/sin tables once per (seq_len, device)."""
        # Inverse frequencies 1 / base^(2i/dim), optionally moved to the target device.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        if device is not None:
            inv_freq = inv_freq.to(device)
        positions = torch.arange(seq_len, dtype=torch.int64, device=inv_freq.device).type_as(inv_freq)
        freqs = torch.outer(positions, inv_freq)
        # Duplicate along the last axis so the table covers the full head dimension.
        table = torch.cat((freqs, freqs), dim=-1)
        return table.cos().contiguous(), table.sin().contiguous()

    def forward(self, hidden_states: torch.Tensor, seq_len: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        if seq_len is None:
            seq_len = hidden_states.shape[1]
        # Build at least max_position_embeddings rows so the cache entry is reused across lengths.
        cos, sin = self._build(max(seq_len, self.max_position_embeddings), hidden_states.device)
        return cos[:seq_len].to(dtype=hidden_states.dtype), sin[:seq_len].to(dtype=hidden_states.dtype)
|
||||
|
||||
|
||||
def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
first, second = hidden_states.chunk(2, dim=-1)
|
||||
return torch.cat((-second, first), dim=-1)
|
||||
|
||||
|
||||
def _apply_rotary_emb(hidden_states: torch.Tensor, rope: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Apply rotary position embedding to a ``(batch, seq, heads, head_dim)`` tensor.

    The rotation is computed in float32 and cast back to the input dtype.
    """
    cos, sin = rope
    # Broadcast the (seq, head_dim) tables over the batch and head axes.
    cos = cos[None, :, None].to(hidden_states.device)
    sin = sin[None, :, None].to(hidden_states.device)
    rotated = _rotate_half(hidden_states).float()
    return (hidden_states.float() * cos + rotated * sin).to(hidden_states.dtype)
|
||||
|
||||
|
||||
class AudioDiTGRN(nn.Module):
    """Global Response Normalization (ConvNeXt V2 style) over the sequence axis.

    ``gamma`` and ``beta`` are initialized to zero, so the module is an identity at init.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Per-channel L2 energy across the sequence dimension, normalized by the channel mean.
        global_norm = torch.norm(hidden_states, p=2, dim=1, keepdim=True)
        response = global_norm / (global_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (hidden_states * response) + self.beta + hidden_states
|
||||
|
||||
|
||||
class AudioDiTConvNeXtV2Block(nn.Module):
    """ConvNeXt V2 block operating on ``(batch, seq, dim)`` tensors.

    Depthwise conv -> LayerNorm -> pointwise expansion -> SiLU -> GRN -> pointwise
    projection, wrapped in a residual connection.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
        kernel_size: int = 7,
        bias: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        # "Same" padding for the dilated depthwise convolution.
        padding = (dilation * (kernel_size - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=kernel_size, padding=padding, groups=dim, dilation=dilation, bias=bias
        )
        self.norm = nn.LayerNorm(dim, eps=eps)
        self.pwconv1 = nn.Linear(dim, intermediate_dim, bias=bias)
        self.act = nn.SiLU()
        self.grn = AudioDiTGRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim, bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        # Conv1d wants channels-first, so transpose around the depthwise conv only.
        out = self.dwconv(hidden_states.transpose(1, 2)).transpose(1, 2)
        out = self.pwconv1(self.norm(out))
        out = self.grn(self.act(out))
        return residual + self.pwconv2(out)
|
||||
|
||||
|
||||
class AudioDiTEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Sequential(nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None) -> torch.Tensor:
|
||||
if mask is not None:
|
||||
hidden_states = hidden_states.masked_fill(mask.logical_not().unsqueeze(-1), 0.0)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
if mask is not None:
|
||||
hidden_states = hidden_states.masked_fill(mask.logical_not().unsqueeze(-1), 0.0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AudioDiTAdaLNMLP(nn.Module):
    """SiLU followed by a linear projection; produces AdaLN modulation parameters."""

    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
        super().__init__()
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(in_dim, out_dim, bias=bias))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.mlp(hidden_states)
|
||||
|
||||
|
||||
class AudioDiTAdaLayerNormZeroFinal(nn.Module):
    """Final adaptive LayerNorm: ``norm(x) * (1 + scale) + shift``, with scale/shift
    projected from a conditioning embedding."""

    def __init__(self, dim: int, bias: bool = True, eps: float = 1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2, bias=bias)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

    def forward(self, hidden_states: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(self.silu(embedding)).chunk(2, dim=-1)
        # Normalize in fp32 for stability, then restore the input dtype.
        normed = self.norm(hidden_states.float()).type_as(hidden_states)
        if scale.ndim == 2:
            # Per-sample modulation: broadcast over the sequence axis.
            return normed * (1 + scale)[:, None, :] + shift[:, None, :]
        return normed * (1 + scale) + shift
|
||||
|
||||
|
||||
class AudioDiTSelfAttnProcessor:
    """Self-attention processor with optional QK RMSNorm and rotary embeddings."""

    _attention_backend = None
    _parallel_config = None

    def __call__(
        self,
        attn: "AudioDiTAttention",
        hidden_states: torch.Tensor,
        attention_mask: torch.BoolTensor | None = None,
        audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        head_dim = attn.inner_dim // attn.heads

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        if attn.qk_norm:
            # RMS-normalize queries and keys before splitting into heads.
            query, key = attn.q_norm(query), attn.k_norm(key)

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)
        if audio_rotary_emb is not None:
            query = _apply_rotary_emb(query, audio_rotary_emb)
            key = _apply_rotary_emb(key, audio_rotary_emb)

        attn_out = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        if attention_mask is not None:
            # Zero out padded query positions after attention.
            attn_out = attn_out * attention_mask[:, :, None, None].to(attn_out.dtype)

        attn_out = attn_out.flatten(2, 3).to(query.dtype)
        return attn.to_out[1](attn.to_out[0](attn_out))
|
||||
|
||||
|
||||
class AudioDiTAttention(nn.Module, AttentionModuleMixin):
    """Multi-head attention used for both self- and cross-attention in the audio DiT."""

    def __init__(
        self,
        q_dim: int,
        kv_dim: int | None,
        heads: int,
        dim_head: int,
        dropout: float = 0.0,
        bias: bool = True,
        qk_norm: bool = False,
        eps: float = 1e-6,
        processor: AttentionModuleMixin | None = None,
    ):
        super().__init__()
        # Self-attention when kv_dim is None: keys/values share the query dimension.
        kv_dim = q_dim if kv_dim is None else kv_dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.to_q = nn.Linear(q_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=bias)
        self.qk_norm = qk_norm
        if qk_norm:
            self.q_norm = RMSNorm(self.inner_dim, eps=eps)
            self.k_norm = RMSNorm(self.inner_dim, eps=eps)
        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)])
        self.set_processor(processor or AudioDiTSelfAttnProcessor())

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        post_attention_mask: torch.BoolTensor | None = None,
        attention_mask: torch.BoolTensor | None = None,
        audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        # Self-attention path: the self-attn processor accepts a reduced argument set.
        if encoder_hidden_states is None:
            return self.processor(
                self,
                hidden_states,
                attention_mask=attention_mask,
                audio_rotary_emb=audio_rotary_emb,
            )
        # Cross-attention path: forward the full conditioning argument set.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            post_attention_mask=post_attention_mask,
            attention_mask=attention_mask,
            audio_rotary_emb=audio_rotary_emb,
            prompt_rotary_emb=prompt_rotary_emb,
        )
|
||||
|
||||
|
||||
class AudioDiTCrossAttnProcessor:
    """Cross-attention processor: queries from audio latents, keys/values from text."""

    _attention_backend = None
    _parallel_config = None

    def __call__(
        self,
        attn: "AudioDiTAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        post_attention_mask: torch.BoolTensor | None = None,
        attention_mask: torch.BoolTensor | None = None,
        audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        head_dim = attn.inner_dim // attn.heads

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        if attn.qk_norm:
            query, key = attn.q_norm(query), attn.k_norm(key)

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)
        # Separate rotary tables for the audio (query) and text (key) sequences.
        if audio_rotary_emb is not None:
            query = _apply_rotary_emb(query, audio_rotary_emb)
        if prompt_rotary_emb is not None:
            key = _apply_rotary_emb(key, prompt_rotary_emb)

        attn_out = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        if post_attention_mask is not None:
            # Zero out padded audio positions after attending to the text.
            attn_out = attn_out * post_attention_mask[:, :, None, None].to(attn_out.dtype)

        attn_out = attn_out.flatten(2, 3).to(query.dtype)
        return attn.to_out[1](attn.to_out[0](attn_out))
|
||||
|
||||
|
||||
class AudioDiTFeedForward(nn.Module):
    """Standard transformer MLP: linear -> tanh-GELU -> dropout -> linear."""

    def __init__(self, dim: int, mult: float = 4.0, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        hidden_dim = int(dim * mult)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=bias),
            nn.GELU(approximate="tanh"),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim, bias=bias),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.ff(hidden_states)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
class AudioDiTBlock(nn.Module):
    """One DiT layer: AdaLN-modulated self-attention, optional cross-attention, and an MLP."""

    def __init__(
        self,
        dim: int,
        cond_dim: int,
        heads: int,
        dim_head: int,
        dropout: float = 0.0,
        bias: bool = True,
        qk_norm: bool = False,
        eps: float = 1e-6,
        cross_attn: bool = True,
        cross_attn_norm: bool = False,
        adaln_type: str = "global",
        adaln_use_text_cond: bool = True,
        ff_mult: float = 4.0,
    ):
        super().__init__()
        self.adaln_type = adaln_type
        self.adaln_use_text_cond = adaln_use_text_cond
        if adaln_type == "local":
            # Per-block MLP producing the six modulation tensors.
            self.adaln_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True)
        elif adaln_type == "global":
            # Learned per-block offset added to the globally shared modulation.
            self.adaln_scale_shift = nn.Parameter(torch.randn(dim * 6) / dim**0.5)

        self.self_attn = AudioDiTAttention(
            dim, None, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps
        )

        self.use_cross_attn = cross_attn
        if cross_attn:
            self.cross_attn = AudioDiTAttention(
                dim,
                cond_dim,
                heads,
                dim_head,
                dropout=dropout,
                bias=bias,
                qk_norm=qk_norm,
                eps=eps,
                processor=AudioDiTCrossAttnProcessor(),
            )
            self.cross_attn_norm = (
                nn.LayerNorm(dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity()
            )
            self.cross_attn_norm_c = (
                nn.LayerNorm(cond_dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity()
            )
        self.ffn = AudioDiTFeedForward(dim=dim, mult=ff_mult, dropout=dropout, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep_embed: torch.Tensor,
        cond: torch.Tensor,
        mask: torch.BoolTensor | None = None,
        cond_mask: torch.BoolTensor | None = None,
        rope: tuple | None = None,
        cond_rope: tuple | None = None,
        adaln_global_out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if self.adaln_type == "local" and adaln_global_out is None:
            if self.adaln_use_text_cond:
                # Mean-pool the text condition over valid tokens and add it to the timestep.
                valid = cond_mask.sum(1, keepdim=True).clamp(min=1).to(cond.dtype)
                norm_cond = timestep_embed + cond.sum(1) / valid
            else:
                norm_cond = timestep_embed
            modulation = self.adaln_mlp(norm_cond)
        else:
            modulation = adaln_global_out + self.adaln_scale_shift.unsqueeze(0)
        gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = modulation.chunk(6, dim=-1)

        # Self-attention with AdaLN modulation (LayerNorm computed in fp32 for stability).
        normed = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(hidden_states)
        normed = normed * (1 + scale_sa[:, None]) + shift_sa[:, None]
        attn_out = self.self_attn(normed, attention_mask=mask, audio_rotary_emb=rope)
        hidden_states = hidden_states + gate_sa.unsqueeze(1) * attn_out

        if self.use_cross_attn:
            cross_out = self.cross_attn(
                hidden_states=self.cross_attn_norm(hidden_states),
                encoder_hidden_states=self.cross_attn_norm_c(cond),
                post_attention_mask=mask,
                attention_mask=cond_mask,
                audio_rotary_emb=rope,
                prompt_rotary_emb=cond_rope,
            )
            hidden_states = hidden_states + cross_out

        # Feed-forward with its own AdaLN modulation.
        normed = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(hidden_states)
        normed = normed * (1 + scale_ffn[:, None]) + shift_ffn[:, None]
        hidden_states = hidden_states + gate_ffn.unsqueeze(1) * self.ffn(normed)
        return hidden_states
|
||||
|
||||
|
||||
class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin):
    """Diffusion Transformer (DiT) denoiser for LongCat audio latents.

    Conditions on a text embedding sequence (cross-attention and/or AdaLN) and a
    diffusion timestep; optionally on a second latent stream (``latent_cond``).
    """

    _supports_gradient_checkpointing = False
    _repeated_blocks = ["AudioDiTBlock"]

    @register_to_config
    def __init__(
        self,
        dit_dim: int = 1536,
        dit_depth: int = 24,
        dit_heads: int = 24,
        dit_text_dim: int = 768,
        latent_dim: int = 64,
        dropout: float = 0.0,
        bias: bool = True,
        cross_attn: bool = True,
        adaln_type: str = "global",
        adaln_use_text_cond: bool = True,
        long_skip: bool = True,
        text_conv: bool = True,
        qk_norm: bool = True,
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        use_latent_condition: bool = True,
    ):
        super().__init__()
        dim = dit_dim
        dim_head = dim // dit_heads
        self.time_embed = AudioDiTTimestepEmbedding(dim)
        self.input_embed = AudioDiTEmbedder(latent_dim, dim)
        self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
        self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
        self.blocks = nn.ModuleList(
            [
                AudioDiTBlock(
                    dim=dim,
                    cond_dim=dim,
                    heads=dit_heads,
                    dim_head=dim_head,
                    dropout=dropout,
                    bias=bias,
                    qk_norm=qk_norm,
                    eps=eps,
                    cross_attn=cross_attn,
                    cross_attn_norm=cross_attn_norm,
                    adaln_type=adaln_type,
                    adaln_use_text_cond=adaln_use_text_cond,
                    ff_mult=4.0,
                )
                for _ in range(dit_depth)
            ]
        )
        self.norm_out = AudioDiTAdaLayerNormZeroFinal(dim, bias=bias, eps=eps)
        self.proj_out = nn.Linear(dim, latent_dim)
        if adaln_type == "global":
            # Single MLP shared by all blocks; each block adds its own learned offset.
            self.adaln_global_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True)
        self.text_conv = text_conv
        if text_conv:
            self.text_conv_layer = nn.Sequential(
                *[AudioDiTConvNeXtV2Block(dim, dim * 2, bias=bias, eps=eps) for _ in range(4)]
            )
        self.use_latent_condition = use_latent_condition
        if use_latent_condition:
            self.latent_embed = AudioDiTEmbedder(latent_dim, dim)
            self.latent_cond_embedder = AudioDiTEmbedder(dim * 2, dim)
        self._initialize_weights(bias=bias)

    def _initialize_weights(self, bias: bool = True):
        """Zero-init the AdaLN projections and the output head (standard DiT zero-init)."""
        if self.config.adaln_type == "local":
            for block in self.blocks:
                nn.init.constant_(block.adaln_mlp.mlp[-1].weight, 0)
                if bias:
                    nn.init.constant_(block.adaln_mlp.mlp[-1].bias, 0)
        elif self.config.adaln_type == "global":
            nn.init.constant_(self.adaln_global_mlp.mlp[-1].weight, 0)
            if bias:
                nn.init.constant_(self.adaln_global_mlp.mlp[-1].bias, 0)
        nn.init.constant_(self.norm_out.linear.weight, 0)
        nn.init.constant_(self.proj_out.weight, 0)
        if bias:
            nn.init.constant_(self.norm_out.linear.bias, 0)
            nn.init.constant_(self.proj_out.bias, 0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.BoolTensor,
        timestep: torch.Tensor,
        attention_mask: torch.BoolTensor | None = None,
        latent_cond: torch.Tensor | None = None,
        return_dict: bool = True,
    ) -> LongCatAudioDiTTransformerOutput | tuple[torch.Tensor]:
        dtype = hidden_states.dtype
        encoder_hidden_states = encoder_hidden_states.to(dtype)
        timestep = timestep.to(dtype)
        if timestep.ndim == 0:
            # Broadcast a scalar timestep across the batch.
            timestep = timestep.repeat(hidden_states.shape[0])
        timestep_embed = self.time_embed(timestep)

        text_mask = encoder_attention_mask.bool()
        encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
        if self.text_conv:
            encoder_hidden_states = self.text_conv_layer(encoder_hidden_states)
            # The conv stack can bleed into padded positions; re-zero them.
            encoder_hidden_states = encoder_hidden_states.masked_fill(text_mask.logical_not().unsqueeze(-1), 0.0)

        hidden_states = self.input_embed(hidden_states, attention_mask)
        if self.use_latent_condition and latent_cond is not None:
            latent_cond = self.latent_embed(latent_cond.to(hidden_states.dtype), attention_mask)
            hidden_states = self.latent_cond_embedder(torch.cat([hidden_states, latent_cond], dim=-1))

        residual = hidden_states.clone() if self.config.long_skip else None
        rope = self.rotary_embed(hidden_states, hidden_states.shape[1])
        cond_rope = self.rotary_embed(encoder_hidden_states, encoder_hidden_states.shape[1])

        # Global AdaLN: compute the shared modulation once; local AdaLN leaves it None
        # and each block derives its own modulation internally.
        adaln_global_out = None
        norm_cond = timestep_embed
        if self.config.adaln_type == "global":
            if self.config.adaln_use_text_cond:
                text_len = text_mask.sum(1).clamp(min=1).to(encoder_hidden_states.dtype)
                norm_cond = timestep_embed + encoder_hidden_states.sum(1) / text_len.unsqueeze(1)
            adaln_global_out = self.adaln_global_mlp(norm_cond)

        for block in self.blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                timestep_embed=timestep_embed,
                cond=encoder_hidden_states,
                mask=attention_mask,
                cond_mask=text_mask,
                rope=rope,
                cond_rope=cond_rope,
                adaln_global_out=adaln_global_out,
            )

        if self.config.long_skip:
            hidden_states = hidden_states + residual
        hidden_states = self.proj_out(self.norm_out(hidden_states, norm_cond))
        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        if not return_dict:
            return (hidden_states,)
        return LongCatAudioDiTTransformerOutput(sample=hidden_states)
|
||||
@@ -777,7 +777,8 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
|
||||
# Pad token
|
||||
feats_cat = torch.cat(feats, dim=0)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token
|
||||
mask = torch.cat(inner_pad_mask).unsqueeze(-1)
|
||||
feats_cat = torch.where(mask, pad_token, feats_cat)
|
||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
||||
|
||||
# RoPE
|
||||
|
||||
@@ -326,6 +326,7 @@ else:
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
_import_structure["longcat_image"] = ["LongCatImagePipeline", "LongCatImageEditPipeline"]
|
||||
_import_structure["longcat_audio_dit"] = ["LongCatAudioDiTPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
@@ -753,6 +754,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
|
||||
from .longcat_audio_dit import LongCatAudioDiTPipeline
|
||||
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
|
||||
from .ltx import (
|
||||
LTXConditionPipeline,
|
||||
|
||||
40
src/diffusers/pipelines/longcat_audio_dit/__init__.py
Normal file
40
src/diffusers/pipelines/longcat_audio_dit/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

# Register the pipeline lazily; fall back to dummy objects when torch/transformers are missing.
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa: F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_longcat_audio_dit"] = ["LongCatAudioDiTPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path for type checkers and slow-import mode.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_longcat_audio_dit import LongCatAudioDiTPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,332 @@
|
||||
# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Adapted from the LongCat-AudioDiT reference implementation:
|
||||
# https://github.com/meituan-longcat/LongCat-AudioDiT
|
||||
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import PreTrainedTokenizerBase, UMT5EncoderModel
|
||||
|
||||
from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _lens_to_mask(lengths: torch.Tensor, length: int | None = None) -> torch.BoolTensor:
|
||||
if length is None:
|
||||
length = int(lengths.amax().item())
|
||||
seq = torch.arange(length, device=lengths.device)
|
||||
return seq[None, :] < lengths[:, None]
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
|
||||
text = text.lower()
|
||||
text = re.sub(r'["“”‘’]', " ", text)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
|
||||
if not text:
|
||||
return 0.0
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
en_dur_per_char = 0.082
|
||||
zh_dur_per_char = 0.21
|
||||
durations = []
|
||||
for prompt in text:
|
||||
prompt = re.sub(r"\s+", "", prompt)
|
||||
num_zh = num_en = num_other = 0
|
||||
for char in prompt:
|
||||
if "一" <= char <= "鿿":
|
||||
num_zh += 1
|
||||
elif char.isalpha():
|
||||
num_en += 1
|
||||
else:
|
||||
num_other += 1
|
||||
if num_zh > num_en:
|
||||
num_zh += num_other
|
||||
else:
|
||||
num_en += num_other
|
||||
durations.append(num_zh * zh_dur_per_char + num_en * en_dur_per_char)
|
||||
return min(max_duration, max(durations)) if durations else 0.0
|
||||
|
||||
|
||||
class LongCatAudioDiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-audio generation with LongCat-AudioDiT.

    The prompt is encoded with a UMT5 text encoder, a latent sequence is denoised by a DiT
    transformer under flow matching, and the latents are decoded to a waveform by the audio VAE.

    Args:
        vae ([`LongCatAudioDiTVae`]): Audio VAE used to decode latents into waveforms.
        text_encoder ([`~transformers.UMT5EncoderModel`]): Frozen text encoder.
        tokenizer ([`~transformers.PreTrainedTokenizerBase`]): Tokenizer matching `text_encoder`.
        transformer ([`LongCatAudioDiTTransformer`]): DiT denoiser operating on audio latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`], *optional*):
            Flow-matching scheduler. When `None` (or an incompatible scheduler) is passed, a
            default `FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True)` is created.
    """

    # Offload order for `enable_model_cpu_offload`.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Tensors that `callback_on_step_end` is allowed to inspect/override.
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        vae: LongCatAudioDiTVae,
        text_encoder: UMT5EncoderModel,
        tokenizer: PreTrainedTokenizerBase,
        transformer: LongCatAudioDiTTransformer,
        scheduler: FlowMatchEulerDiscreteScheduler | None = None,
    ):
        super().__init__()
        # Fall back to the flow-match scheduler configuration also used by the repo's tests
        # whenever the caller passes None or a scheduler of the wrong type.
        if not isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
            scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Conversion factors between seconds of audio and latent-sequence length.
        self.sample_rate = getattr(vae.config, "sample_rate", 24000)
        self.vae_scale_factor = getattr(vae.config, "downsampling_ratio", 2048)
        self.latent_dim = getattr(transformer.config, "latent_dim", 64)
        # Hard cap on generated audio length, in seconds.
        self.max_wav_duration = 30.0
        # Text-embedding post-processing switches used by `encode_prompt`.
        self.text_norm_feat = True
        self.text_add_embed = True

    @property
    def guidance_scale(self):
        # Classifier-free guidance scale of the current/last `__call__`.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps of the current/last `__call__`.
        return self._num_timesteps

    def encode_prompt(self, prompt: str | list[str], device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize and encode `prompt`; return `(prompt_embeds, lengths)`.

        `prompt_embeds` has shape (batch, seq, hidden). `lengths` holds the unpadded token
        count per prompt (sum of the attention mask). When `text_norm_feat` is set, embeddings
        are layer-normalized; when `text_add_embed` is set, the (normalized) first hidden state
        (the input embeddings layer) is added on top.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        # Some tokenizers report a huge sentinel `model_max_length`; clamp to a sane value.
        model_max_length = getattr(self.tokenizer, "model_max_length", 512)
        if not isinstance(model_max_length, int) or model_max_length <= 0 or model_max_length > 32768:
            model_max_length = 512
        text_inputs = self.tokenizer(
            prompt,
            padding="longest",
            truncation=True,
            max_length=model_max_length,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)
        with torch.no_grad():
            output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        prompt_embeds = output.last_hidden_state
        if self.text_norm_feat:
            prompt_embeds = F.layer_norm(prompt_embeds, (prompt_embeds.shape[-1],), eps=1e-6)
        if self.text_add_embed and getattr(output, "hidden_states", None):
            # hidden_states[0] is the embedding-layer output of the encoder.
            first_hidden = output.hidden_states[0]
            if self.text_norm_feat:
                first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6)
            prompt_embeds = prompt_embeds + first_hidden
        lengths = attention_mask.sum(dim=1).to(device)
        return prompt_embeds, lengths

    def prepare_latents(
        self,
        batch_size: int,
        duration: int,
        device: torch.device,
        dtype: torch.dtype,
        generator: torch.Generator | list[torch.Generator] | None = None,
        latents: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Validate user-supplied `latents` or sample fresh Gaussian noise.

        Returns a tensor of shape `(batch_size, duration, latent_dim)` on `device`/`dtype`.
        Raises `ValueError` on shape/batch/latent-dim mismatches or a generator-list length
        that differs from `batch_size`.
        """
        if latents is not None:
            if latents.ndim != 3:
                raise ValueError(
                    f"`latents` must have shape (batch_size, duration, latent_dim), but got {tuple(latents.shape)}."
                )
            if latents.shape[0] != batch_size:
                raise ValueError(f"`latents` must have batch size {batch_size}, but got {latents.shape[0]}.")
            if latents.shape[2] != self.latent_dim:
                raise ValueError(f"`latents` must have latent_dim {self.latent_dim}, but got {latents.shape[2]}.")
            return latents.to(device=device, dtype=dtype)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}."
            )

        return randn_tensor((batch_size, duration, self.latent_dim), generator=generator, device=device, dtype=dtype)

    def check_inputs(
        self,
        prompt: list[str],
        negative_prompt: str | list[str] | None,
        output_type: str,
        callback_on_step_end_tensor_inputs: list[str] | None = None,
    ) -> None:
        """Raise `ValueError` for empty prompts, unknown `output_type`, disallowed callback
        tensor names, or a `negative_prompt` list whose length differs from `prompt`."""
        if len(prompt) == 0:
            raise ValueError("`prompt` must contain at least one prompt.")

        if output_type not in {"np", "pt", "latent"}:
            raise ValueError(f"Unsupported output_type: {output_type}")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
                f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # A single negative-prompt string is broadcast later; only lists must match the batch.
        if negative_prompt is not None and not isinstance(negative_prompt, str):
            negative_prompt = list(negative_prompt)
            if len(negative_prompt) != len(prompt):
                raise ValueError(
                    f"`negative_prompt` must have batch size {len(prompt)}, but got {len(negative_prompt)} prompts."
                )

    @torch.no_grad()
    def __call__(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] | None = None,
        audio_duration_s: float | None = None,
        latents: torch.Tensor | None = None,
        num_inference_steps: int = 16,
        guidance_scale: float = 4.0,
        generator: torch.Generator | list[torch.Generator] | None = None,
        output_type: str = "np",
        return_dict: bool = True,
        callback_on_step_end: Callable[..., dict] | None = None,
        # NOTE: mutable default mirrors the diffusers-wide pipeline convention; the list is
        # only iterated, never mutated.
        callback_on_step_end_tensor_inputs: list[str] = ["latents"],
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`): Prompt or prompts that guide audio generation.
            negative_prompt (`str` or `list[str]`, *optional*): Negative prompt(s) for classifier-free guidance.
            audio_duration_s (`float`, *optional*):
                Target audio duration in seconds. Ignored when `latents` is provided. When omitted,
                the duration is estimated heuristically from the prompt text.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents of shape `(batch_size, duration, latent_dim)`.
            num_inference_steps (`int`, defaults to 16): Number of denoising steps.
            guidance_scale (`float`, defaults to 4.0): Guidance scale for classifier-free guidance.
                CFG is only applied when the value is greater than 1.0.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*): Random generator(s).
            output_type (`str`, defaults to `"np"`): Output format: `"np"`, `"pt"`, or `"latent"`.
            return_dict (`bool`, defaults to `True`): Whether to return `AudioPipelineOutput`.
            callback_on_step_end (`Callable`, *optional*):
                A function called at the end of each denoising step with the pipeline, step index, timestep, and tensor
                inputs specified by `callback_on_step_end_tensor_inputs`. Must return a dict; `latents` and
                `prompt_embeds` entries of the returned dict override the pipeline's tensors.
            callback_on_step_end_tensor_inputs (`list`, defaults to `["latents"]`):
                Tensor inputs passed to `callback_on_step_end`.

        Returns:
            [`~pipelines.AudioPipelineOutput`] or `tuple`: generated audio (numpy array, tensor,
            or raw latents depending on `output_type`).
        """
        # Normalize `prompt` to a list so batch size and iteration are uniform below.
        if prompt is None:
            prompt = []
        elif isinstance(prompt, str):
            prompt = [prompt]
        else:
            prompt = list(prompt)
        self.check_inputs(prompt, negative_prompt, output_type, callback_on_step_end_tensor_inputs)
        batch_size = len(prompt)
        self._guidance_scale = guidance_scale

        device = self._execution_device
        normalized_prompts = [_normalize_text(text) for text in prompt]
        # Resolve the latent-sequence length: explicit latents win, then the requested
        # duration, then a text-based heuristic. All are expressed in latent frames.
        if latents is not None:
            duration = latents.shape[1]
        elif audio_duration_s is not None:
            duration = int(audio_duration_s * self.sample_rate // self.vae_scale_factor)
        else:
            duration = int(_approx_duration_from_text(normalized_prompts) * self.sample_rate // self.vae_scale_factor)
        max_duration = int(self.max_wav_duration * self.sample_rate // self.vae_scale_factor)
        if latents is None:
            # Keep at least one latent frame and respect the 30 s cap.
            duration = max(1, min(duration, max_duration))

        prompt_embeds, prompt_embeds_len = self.encode_prompt(normalized_prompts, device)
        duration_tensor = torch.full((batch_size,), duration, device=device, dtype=torch.long)
        mask = _lens_to_mask(duration_tensor)
        text_mask = _lens_to_mask(prompt_embeds_len, length=prompt_embeds.shape[1])

        # Unconditional branch for CFG: zero embeddings when no negative prompt is given.
        if negative_prompt is None:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_prompt_embeds_len = prompt_embeds_len
            negative_prompt_embeds_mask = text_mask
        else:
            if isinstance(negative_prompt, str):
                # Broadcast one negative prompt across the batch.
                negative_prompt = [negative_prompt] * batch_size
            else:
                negative_prompt = list(negative_prompt)
            negative_prompt_embeds, negative_prompt_embeds_len = self.encode_prompt(negative_prompt, device)
            negative_prompt_embeds_mask = _lens_to_mask(
                negative_prompt_embeds_len, length=negative_prompt_embeds.shape[1]
            )

        # Zero conditioning latents; the transformer signature expects `latent_cond`.
        latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=prompt_embeds.dtype)
        latents = self.prepare_latents(
            batch_size, duration, device, prompt_embeds.dtype, generator=generator, latents=latents
        )
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be a positive integer.")

        # Uniform sigma grid from 1.0 down to 1/steps; with `invert_sigmas=True` this yields
        # a uniform timestep grid (see the pipeline's scheduler-grid unit test).
        sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps, dtype=torch.float32).tolist()
        self.scheduler.set_timesteps(sigmas=sigmas, device=device)
        self.scheduler.set_begin_index(0)
        timesteps = self.scheduler.timesteps
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Transformer expects a normalized (0..1) per-sample timestep.
                curr_t = (
                    (t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=prompt_embeds.dtype)
                )
                pred = self.transformer(
                    hidden_states=latents,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=text_mask,
                    timestep=curr_t,
                    attention_mask=mask,
                    latent_cond=latent_cond,
                ).sample
                if self.guidance_scale > 1.0:
                    # Classifier-free guidance: second pass with the unconditional embeddings.
                    null_pred = self.transformer(
                        hidden_states=latents,
                        encoder_hidden_states=negative_prompt_embeds,
                        encoder_attention_mask=negative_prompt_embeds_mask,
                        timestep=curr_t,
                        attention_mask=mask,
                        latent_cond=latent_cond,
                    ).sample
                    pred = null_pred + (pred - null_pred) * self.guidance_scale
                latents = self.scheduler.step(pred, t, latents, return_dict=False)[0]
                progress_bar.update()

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # `locals()` lookup works because every allowed name in
                    # `_callback_tensor_inputs` is a local of this function.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

        if output_type == "latent":
            waveform = latents
        else:
            # VAE expects channel-first latents: (batch, latent_dim, duration).
            waveform = self.vae.decode(latents.permute(0, 2, 1)).sample
            if output_type == "np":
                waveform = waveform.cpu().float().numpy()

        # Release any model-offload hooks installed by `enable_model_cpu_offload`.
        self.maybe_free_model_hooks()

        if not return_dict:
            return (waveform,)
        return AudioPipelineOutput(audios=waveform)
|
||||
@@ -486,6 +486,15 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
||||
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
||||
self.scheduler.set_begin_index(0)
|
||||
|
||||
if self.do_classifier_free_guidance and self._cfg_truncation is not None and float(self._cfg_truncation) <= 1:
|
||||
_precomputed_t_norms = ((1000 - timesteps.float()) / 1000).tolist()
|
||||
else:
|
||||
_precomputed_t_norms = None
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -495,17 +504,9 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
if _precomputed_t_norms is not None:
|
||||
if _precomputed_t_norms[i] > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
|
||||
@@ -1395,6 +1395,36 @@ class LatteTransformer3DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LongCatAudioDiTTransformer(metaclass=DummyObject):
    # Placeholder stub exposed when `torch` is unavailable: construction and the
    # loading classmethods all route through `requires_backends(..., ["torch"])`.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LongCatAudioDiTVae(metaclass=DummyObject):
    # Placeholder stub exposed when `torch` is unavailable: construction and the
    # loading classmethods all route through `requires_backends(..., ["torch"])`.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LongCatImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2297,6 +2297,21 @@ class LLaDA2PipelineOutput(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LongCatAudioDiTPipeline(metaclass=DummyObject):
    # Placeholder stub exposed when `torch`/`transformers` are unavailable:
    # all entry points route through `requires_backends`.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LongCatImageEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import LongCatAudioDiTTransformer
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LongCatAudioDiTTransformerTesterConfig(BaseModelTesterConfig):
    """Shared tiny-model configuration mixed into the LongCat AudioDiT model test classes.

    Provides the model class, a minimal init dict, and deterministic dummy inputs
    consumed by the generic tester mixins.
    """

    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def model_class(self):
        return LongCatAudioDiTTransformer

    @property
    def output_shape(self) -> tuple[int, ...]:
        # (sequence_length, latent_dim) — matches `get_dummy_inputs` below.
        return (16, 8)

    @property
    def generator(self):
        # Fresh seeded CPU generator per access so each tensor draw is reproducible.
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | bool | float | str]:
        """Smallest transformer config that still exercises attention and text conditioning."""
        return {
            "dit_dim": 64,
            "dit_depth": 2,
            "dit_heads": 4,
            "dit_text_dim": 32,
            "latent_dim": 8,
            "text_conv": False,
        }

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        """Deterministic forward-pass inputs consistent with `get_init_dict`."""
        batch_size = 1
        sequence_length = 16
        encoder_sequence_length = 10
        latent_dim = 8
        text_dim = 32

        return {
            "hidden_states": randn_tensor(
                (batch_size, sequence_length, latent_dim), generator=self.generator, device=torch_device
            ),
            "encoder_hidden_states": randn_tensor(
                (batch_size, encoder_sequence_length, text_dim), generator=self.generator, device=torch_device
            ),
            "encoder_attention_mask": torch.ones(
                batch_size, encoder_sequence_length, dtype=torch.bool, device=torch_device
            ),
            "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.bool, device=torch_device),
            "timestep": torch.ones(batch_size, device=torch_device),
        }
|
||||
|
||||
|
||||
class TestLongCatAudioDiTTransformer(LongCatAudioDiTTransformerTesterConfig, ModelTesterMixin):
    # Core model tests (forward pass, save/load, etc.) supplied by ModelTesterMixin.
    pass
|
||||
|
||||
|
||||
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
    # Memory-behavior tests supplied by MemoryTesterMixin.
    def test_layerwise_casting_memory(self):
        # The tiny test config is too small for the mixin's peak-memory comparison
        # to be meaningful, so the inherited test is skipped rather than flaky.
        pytest.skip(
            "LongCatAudioDiTTransformer tiny test config does not provide stable layerwise casting peak memory "
            "coverage."
        )
|
||||
|
||||
|
||||
class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
    # torch.compile compatibility tests supplied by TorchCompileTesterMixin.
    pass
|
||||
|
||||
|
||||
class TestLongCatAudioDiTTransformerAttention(LongCatAudioDiTTransformerTesterConfig, AttentionTesterMixin):
    # Attention-backend tests supplied by AttentionTesterMixin.
    pass
|
||||
|
||||
|
||||
def test_longcat_audio_attention_uses_standard_self_attn_kwargs():
    """Check that AudioDiTAttention accepts the standard self-attention kwargs
    (`hidden_states`, `attention_mask`) and that masking takes effect."""
    from diffusers.models.transformers.transformer_longcat_audio_dit import AudioDiTAttention

    attn = AudioDiTAttention(q_dim=4, kv_dim=None, heads=1, dim_head=4, dropout=0.0, bias=False)

    # Identity projections so the output depends only on the attention/masking itself.
    eye = torch.eye(4)
    with torch.no_grad():
        attn.to_q.weight.copy_(eye)
        attn.to_k.weight.copy_(eye)
        attn.to_v.weight.copy_(eye)
        attn.to_out[0].weight.copy_(eye)

    hidden_states = torch.tensor([[[1.0, 0.0, 0.0, 0.0], [0.5, 0.5, 0.5, 0.5]]])
    attention_mask = torch.tensor([[True, False]])

    output = attn(hidden_states=hidden_states, attention_mask=attention_mask)

    # NOTE(review): this asserts the masked-out position yields all zeros, i.e. the
    # mask appears to gate per-position outputs rather than only keys — confirm
    # against AudioDiTAttention's documented mask semantics.
    assert torch.allclose(output[:, 1], torch.zeros_like(output[:, 1]))
|
||||
0
tests/pipelines/longcat_audio_dit/__init__.py
Normal file
0
tests/pipelines/longcat_audio_dit/__init__.py
Normal file
225
tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py
Normal file
225
tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# Copyright 2026 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5Config, UMT5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LongCatAudioDiTPipeline,
|
||||
LongCatAudioDiTTransformer,
|
||||
LongCatAudioDiTVae,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
|
||||
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LongCatAudioDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast CPU tests for `LongCatAudioDiTPipeline` using tiny random components."""

    pipeline_class = LongCatAudioDiTPipeline
    # Start from the shared text-to-audio param set, drop kwargs this pipeline does
    # not expose, and add its own `audio_duration_s` knob.
    params = (
        TEXT_TO_AUDIO_PARAMS
        - {"audio_length_in_s", "prompt_embeds", "negative_prompt_embeds", "cross_attention_kwargs"}
    ) | {"audio_duration_s"}
    batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params - {"num_images_per_prompt"}
    test_attention_slicing = False
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        """Build minimal tokenizer/text-encoder/transformer/VAE components.

        No scheduler is included: the pipeline constructs its default flow-match
        scheduler when none is passed.
        """
        torch.manual_seed(0)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder = UMT5EncoderModel(
            UMT5Config(d_model=32, num_layers=1, num_heads=4, d_ff=64, vocab_size=tokenizer.vocab_size)
        )
        transformer = LongCatAudioDiTTransformer(
            dit_dim=64,
            dit_depth=2,
            dit_heads=4,
            dit_text_dim=32,
            latent_dim=8,
            text_conv=False,
        )
        vae = LongCatAudioDiTVae(
            in_channels=1,
            channels=16,
            c_mults=[1, 2],
            strides=[2],
            latent_dim=8,
            encoder_latent_dim=16,
            downsampling_ratio=2,
            sample_rate=24000,
        )

        return {
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
        }

    def get_dummy_inputs(self, device, seed=0, prompt="soft ocean ambience"):
        """Minimal call kwargs (2 steps, 0.1 s, no CFG) with a seeded generator."""
        if str(device).startswith("mps"):
            # MPS does not support device-specific generators; use the global seed.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        return {
            "prompt": prompt,
            "audio_duration_s": 0.1,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "generator": generator,
            "output_type": "pt",
        }

    def test_inference(self):
        # Smoke test: one mono waveform of non-zero length comes back.
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(**self.get_dummy_inputs(device)).audios

        self.assertEqual(output.ndim, 3)
        self.assertEqual(output.shape[0], 1)
        self.assertEqual(output.shape[1], 1)
        self.assertGreater(output.shape[-1], 0)

    def test_save_load_local(self):
        # Round-trip save_pretrained/from_pretrained and run the reloaded pipeline.
        import tempfile

        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to(device)

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipe.save_pretrained(tmp_dir)
            reloaded = self.pipeline_class.from_pretrained(tmp_dir, local_files_only=True)
            output = reloaded(**self.get_dummy_inputs(device, seed=0)).audios

        self.assertIsInstance(reloaded, LongCatAudioDiTPipeline)
        self.assertEqual(output.ndim, 3)
        self.assertGreater(output.shape[-1], 0)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=2e-3)

    def test_model_cpu_offload_forward_pass(self):
        self.skipTest(
            "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test."
        )

    def test_cpu_offload_forward_pass_twice(self):
        self.skipTest(
            "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test."
        )

    def test_sequential_cpu_offload_forward_pass(self):
        self.skipTest(
            "LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with "
            "sequential offloading."
        )

    def test_sequential_offload_forward_pass_twice(self):
        self.skipTest(
            "LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with "
            "sequential offloading."
        )

    def test_pipeline_level_group_offloading_inference(self):
        self.skipTest(
            "LongCatAudioDiTPipeline group offloading coverage is not ready for the standard PipelineTesterMixin test."
        )

    def test_num_images_per_prompt(self):
        self.skipTest("LongCatAudioDiTPipeline does not support num_images_per_prompt.")

    def test_encode_prompt_works_in_isolation(self):
        self.skipTest("LongCatAudioDiTPipeline.encode_prompt has a custom signature.")

    def test_uniform_flow_match_scheduler_grid_matches_manual_updates(self):
        """The pipeline's sigma schedule must yield a uniform timestep grid whose Euler
        updates match a manual integration over the same grid."""
        num_inference_steps = 6
        scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True)
        # Same sigma construction as the pipeline's `__call__`.
        sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps, dtype=torch.float32).tolist()
        scheduler.set_timesteps(sigmas=sigmas, device="cpu")

        expected_grid = torch.linspace(0, 1, num_inference_steps + 1, dtype=torch.float32)
        actual_timesteps = scheduler.timesteps / scheduler.config.num_train_timesteps
        self.assertTrue(torch.allclose(actual_timesteps, expected_grid[:-1], atol=1e-6, rtol=0))

        # Manually integrate a constant velocity field and compare with scheduler.step.
        sample = torch.zeros(1, 2, 3)
        model_output = torch.ones_like(sample)
        expected = sample.clone()
        for t0, t1, scheduler_t in zip(expected_grid[:-1], expected_grid[1:], scheduler.timesteps):
            expected = expected + model_output * (t1 - t0)
            sample = scheduler.step(model_output, scheduler_t, sample, return_dict=False)[0]

        self.assertTrue(torch.allclose(sample, expected, atol=1e-6, rtol=0))
|
||||
|
||||
|
||||
def test_longcat_audio_top_level_imports():
    """The LongCat-AudioDiT classes are importable from the top-level `diffusers` namespace."""
    for exported in (LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae):
        assert exported is not None
|
||||
|
||||
|
||||
@slow
@require_torch_accelerator
class LongCatAudioDiTPipelineSlowTests(unittest.TestCase):
    """Integration test against real local LongCat-AudioDiT weights.

    Skipped unless `LONGCAT_AUDIO_DIT_TOKENIZER_PATH` is set and both the model and
    tokenizer paths exist on disk; never downloads (local_files_only).
    """

    pipeline_class = LongCatAudioDiTPipeline

    def test_longcat_audio_pipeline_from_pretrained_real_local_weights(self):
        # Model path can be overridden via env var; defaults to a CI-local location.
        model_path = Path(
            os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", "/data/models/meituan-longcat/LongCat-AudioDiT-1B")
        )
        tokenizer_path_env = os.getenv("LONGCAT_AUDIO_DIT_TOKENIZER_PATH")
        if tokenizer_path_env is None:
            raise unittest.SkipTest("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set")
        tokenizer_path = Path(tokenizer_path_env)

        if not model_path.exists():
            raise unittest.SkipTest(f"LongCat-AudioDiT model path not found: {model_path}")
        if not tokenizer_path.exists():
            raise unittest.SkipTest(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}")

        pipe = LongCatAudioDiTPipeline.from_pretrained(
            model_path,
            tokenizer=tokenizer_path,
            torch_dtype=torch.float16,
            local_files_only=True,
        )
        pipe = pipe.to(torch_device)

        # Short, 2-step run: only shape sanity is asserted, not audio quality.
        result = pipe(
            prompt="A calm ocean wave ambience with soft wind in the background.",
            audio_duration_s=2.0,
            num_inference_steps=2,
            guidance_scale=4.0,
            output_type="pt",
        )

        assert result.audios.ndim == 3
        assert result.audios.shape[0] == 1
        assert result.audios.shape[1] == 1
        assert result.audios.shape[-1] > 0
|
||||
Reference in New Issue
Block a user