Compare commits

...

4 Commits

Author SHA1 Message Date
sayakpaul
34e607c757 kay 2025-09-06 11:34:41 +05:30
sayakpaul
6cc6c130cf fixes 2025-07-07 14:22:10 +05:30
sayakpaul
666a3d9448 update 2025-07-07 13:47:12 +05:30
sayakpaul
b89b5d1338 feat: add batching support in Flux RoPE for metaqueries 2025-07-07 09:57:40 +05:30
3 changed files with 53 additions and 21 deletions

14
check_rope_batched.py Normal file
View File

@@ -0,0 +1,14 @@
from diffusers.models.embeddings import FluxPosEmbed
import torch
batch_size = 4
seq_length = 16
img_seq_length = 32
txt_ids = torch.randn(batch_size, seq_length, 3)
img_ids = torch.randn(batch_size, img_seq_length, 3)
pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 4, 8])
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = pos_embed(ids)
# image_rotary_emb[0].shape=torch.Size([4, 48, 16]), image_rotary_emb[1].shape=torch.Size([4, 48, 16])
print(f"{image_rotary_emb[0].shape=}, {image_rotary_emb[1].shape=}")

View File

@@ -1147,32 +1147,38 @@ def get_1d_rotary_pos_embed(
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
# Handle both batched [B, S] and un-batched [S] inputs
if pos.ndim == 1:
pos = pos.unsqueeze(0) # Add a batch dimension if missing
theta = theta * ntk_factor
freqs = (
1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
) # Shape: [D/2]
# Replace torch.outer with broadcasted multiplication
# Old: freqs = torch.outer(pos, freqs) # Shape: [S, D/2]
# New: pos is [B, S], freqs is [D/2]. Unsqueeze pos to [B, S, 1] for broadcasting.
freqs = pos.unsqueeze(-1) * freqs # Shape: [B, S, D/2]
is_npu = freqs.device.type == "npu"
if is_npu:
freqs = freqs.float()
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
# Use dim=-1 for robust interleaving on the feature dimension
freqs_cos = freqs.cos().repeat_interleave(2, dim=-1).float() # Shape: [B, S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=-1).float() # Shape: [B, S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # Shape: [B, S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # Shape: [B, S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Shape: [B, S, D/2]
return freqs_cis

View File

@@ -489,6 +489,11 @@ class FluxPosEmbed(nn.Module):
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
was_unbatched = ids.ndim == 2
if was_unbatched:
# Add a batch dimension to standardize processing
ids = ids.unsqueeze(0)
# ids is now expected to be [B, S, n_axes]
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
@@ -496,10 +501,11 @@ class FluxPosEmbed(nn.Module):
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
pos[:, :, i], # Correct slicing for batched input
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
@@ -507,8 +513,15 @@ class FluxPosEmbed(nn.Module):
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
# Squeeze the batch dim if the original input was unbatched
if was_unbatched:
freqs_cos = freqs_cos.squeeze(0)
freqs_sin = freqs_sin.squeeze(0)
return freqs_cos, freqs_sin
@@ -685,18 +698,17 @@ class FluxTransformer2DModel(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
# logger.warning(
# "Passing `txt_ids` 3d torch.Tensor is deprecated."
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
# )
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
# logger.warning(
# "Passing `img_ids` 3d torch.Tensor is deprecated."
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
# )
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)