Merge branch 'huggingface:main' into rae

This commit is contained in:
Ando
2026-03-04 19:21:12 +08:00
committed by GitHub

View File

@@ -1232,22 +1232,49 @@ def main(args):
id_token=args.id_token,
)
def encode_video(video, bar):
bar.update(1)
def encode_video(video):
    """Encode one video tensor with the VAE and return its latent distribution.

    The tensor is moved to ``accelerator.device`` in the VAE's dtype, given a
    leading batch axis, and has axes 1 and 2 swapped before being encoded.
    """
    batched = video.to(accelerator.device, dtype=vae.dtype)
    batched = batched.unsqueeze(0)
    # Reorder to the layout the VAE is fed: [B, C, F, H, W]
    channels_first = batched.permute(0, 2, 1, 3, 4)
    return vae.encode(channels_first).latent_dist
# Distribute video encoding across processes: each process only encodes its own shard
num_videos = len(train_dataset.instance_videos)
num_procs = accelerator.num_processes
local_rank = accelerator.process_index
local_count = len(range(local_rank, num_videos, num_procs))
progress_encode_bar = tqdm(
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
range(local_count),
desc="Encoding videos",
disable=not accelerator.is_local_main_process,
)
train_dataset.instance_videos = [
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
]
# Each process encodes only its own round-robin shard: video ``i`` belongs to
# process ``i % num_procs``. Slots owned by other processes stay ``None`` until
# the broadcast below fills them in.
encoded_videos = [None] * num_videos
for i, video in enumerate(train_dataset.instance_videos):
    if i % num_procs == local_rank:
        encoded_videos[i] = encode_video(video)
        progress_encode_bar.update(1)
progress_encode_bar.close()

# Broadcast encoded latent distributions so every process has the full set
if num_procs > 1:
    import torch.distributed as dist

    from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

    for i in range(num_videos):
        src = i % num_procs
        is_src = encoded_videos[i] is not None
        if is_src:
            params = encoded_videos[i].parameters.contiguous()
            meta = [(tuple(params.shape), params.dtype)]
        else:
            params = None
            meta = [None]

        # Ship shape/dtype ahead of the payload. Receivers must not size their
        # buffer from a locally encoded video: a process may own no videos at
        # all (num_videos < num_procs), and videos with different frame counts
        # or resolutions produce differently shaped latent parameters.
        dist.broadcast_object_list(meta, src=src, device=accelerator.device)

        if not is_src:
            shape, dtype = meta[0]
            params = torch.empty(shape, dtype=dtype, device=accelerator.device)

        dist.broadcast(params, src=src)
        # Rebuild the distribution on every rank (including the source) from
        # the broadcast parameter tensor.
        encoded_videos[i] = DiagonalGaussianDistribution(params)

train_dataset.instance_videos = encoded_videos
def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
prompts = [example["instance_prompt"] for example in examples]