Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-10 22:44:38 +08:00

Comparing `enable-cp-...` and `flux-rope-...` (4 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 34e607c757 |  |
|  | 6cc6c130cf |  |
|  | 666a3d9448 |  |
|  | b89b5d1338 |  |
`check_rope_batched.py` (new file, 14 lines)

```python
from diffusers.models.embeddings import FluxPosEmbed
import torch

batch_size = 4
seq_length = 16
img_seq_length = 32
txt_ids = torch.randn(batch_size, seq_length, 3)
img_ids = torch.randn(batch_size, img_seq_length, 3)

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 4, 8])
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = pos_embed(ids)
# image_rotary_emb[0].shape=torch.Size([4, 48, 16]), image_rotary_emb[1].shape=torch.Size([4, 48, 16])
print(f"{image_rotary_emb[0].shape=}, {image_rotary_emb[1].shape=}")
```
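For contrast with the batched script above, here is a minimal sketch of the pre-existing unbatched call path, assuming the patched `forward` squeezes the internally added batch dim back out for 2-D inputs (as the `FluxPosEmbed` hunks below do):

```python
import torch
from diffusers.models.embeddings import FluxPosEmbed

# Unbatched ids: [S, 3], no leading batch dimension
txt_ids = torch.randn(16, 3)
img_ids = torch.randn(32, 3)

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 4, 8])
ids = torch.cat((txt_ids, img_ids), dim=0)  # concatenate along the sequence axis
cos, sin = pos_embed(ids)
# The patch tracks `was_unbatched` and squeezes the batch dim back out,
# so the expected shapes are [48, 16] rather than [1, 48, 16].
print(f"{cos.shape=}, {sin.shape=}")
```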
```diff
@@ -1147,32 +1147,38 @@ def get_1d_rotary_pos_embed(
     """
     assert dim % 2 == 0
 
     if isinstance(pos, int):
         pos = torch.arange(pos)
     if isinstance(pos, np.ndarray):
         pos = torch.from_numpy(pos)  # type: ignore  # [S]
+    # Handle both batched [B, S] and un-batched [S] inputs
+    if pos.ndim == 1:
+        pos = pos.unsqueeze(0)  # Add a batch dimension if missing
 
     theta = theta * ntk_factor
     freqs = (
         1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
-    )  # [D/2]
-    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+    )  # Shape: [D/2]
+
+    # Replace torch.outer with broadcasted multiplication
+    # Old: freqs = torch.outer(pos, freqs)  # Shape: [S, D/2]
+    # New: pos is [B, S], freqs is [D/2]. Unsqueeze pos to [B, S, 1] for broadcasting.
+    freqs = pos.unsqueeze(-1) * freqs  # Shape: [B, S, D/2]
 
     is_npu = freqs.device.type == "npu"
     if is_npu:
         freqs = freqs.float()
 
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        # Use dim=-1 for robust interleaving on the feature dimension
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=-1).float()  # Shape: [B, S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=-1).float()  # Shape: [B, S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # Shape: [B, S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # Shape: [B, S, D]
         return freqs_cos, freqs_sin
     else:
         # lumina
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Shape: [B, S, D/2]
         return freqs_cis
```
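The core of this hunk is replacing `torch.outer`, which only accepts 1-D inputs, with a broadcasted multiply. A small self-contained sketch (the shapes are illustrative, not taken from the patch) showing that the two forms agree per batch row, and that `repeat_interleave(2, dim=-1)` doubles the feature dimension:

```python
import torch

B, S, half_dim = 4, 48, 8  # hypothetical shapes for illustration
pos = torch.arange(S, dtype=torch.float64).repeat(B, 1)  # [B, S]
inv_freq = 1.0 / (10000 ** (torch.arange(0, 2 * half_dim, 2, dtype=torch.float64) / (2 * half_dim)))  # [D/2]

# Broadcasted form from the patch: [B, S, 1] * [D/2] -> [B, S, D/2]
freqs = pos.unsqueeze(-1) * inv_freq

# Each batch row matches the old torch.outer over a 1-D pos
for b in range(B):
    assert torch.allclose(freqs[b], torch.outer(pos[b], inv_freq))

# repeat_interleave on the last dim interleaves each frequency:
# [f0, f1, ...] -> [f0, f0, f1, f1, ...], so [B, S, D/2] -> [B, S, D]
print(freqs.cos().repeat_interleave(2, dim=-1).shape)  # torch.Size([4, 48, 16])
```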
```diff
@@ -489,6 +489,11 @@ class FluxPosEmbed(nn.Module):
         self.axes_dim = axes_dim
 
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        was_unbatched = ids.ndim == 2
+        if was_unbatched:
+            # Add a batch dimension to standardize processing
+            ids = ids.unsqueeze(0)
+        # ids is now expected to be [B, S, n_axes]
         n_axes = ids.shape[-1]
         cos_out = []
         sin_out = []
```
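A minimal standalone sketch of the normalization idiom added here: 2-D inputs are promoted to a canonical 3-D layout, and the fact is remembered so the output can be squeezed back later.

```python
import torch

ids = torch.randn(48, 3)         # unbatched [S, n_axes]
was_unbatched = ids.ndim == 2
if was_unbatched:
    ids = ids.unsqueeze(0)       # -> [1, S, n_axes]
print(ids.shape, was_unbatched)  # torch.Size([1, 48, 3]) True
```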
```diff
@@ -496,10 +501,11 @@ class FluxPosEmbed(nn.Module):
         is_mps = ids.device.type == "mps"
         is_npu = ids.device.type == "npu"
         freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
 
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
-                pos[:, i],
+                pos[:, :, i],  # Correct slicing for batched input
                 theta=self.theta,
                 repeat_interleave_real=True,
                 use_real=True,
```
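The slicing fix follows directly from the extra leading batch dimension; a short sketch with hypothetical shapes:

```python
import torch

B, S, n_axes = 2, 6, 3
ids = torch.randn(B, S, n_axes)  # batched [B, S, n_axes]

# Old (unbatched) code assumed [S, n_axes]:  ids[:, i]    -> [S]
# New (batched) code takes axis i per batch: ids[:, :, i] -> [B, S]
for i in range(n_axes):
    assert ids[:, :, i].shape == (B, S)
```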
```diff
@@ -507,8 +513,15 @@ class FluxPosEmbed(nn.Module):
             )
             cos_out.append(cos)
             sin_out.append(sin)
 
         freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
         freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
 
+        # Squeeze the batch dim if the original input was unbatched
+        if was_unbatched:
+            freqs_cos = freqs_cos.squeeze(0)
+            freqs_sin = freqs_sin.squeeze(0)
+
         return freqs_cos, freqs_sin
```
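Putting the per-axis outputs together: each axis contributes cos/sin of width `axes_dim[i]`, so concatenating on the last dim yields `sum(axes_dim)`. A sketch with stand-in tensors, matching the shapes printed by `check_rope_batched.py`:

```python
import torch

axes_dim = [4, 4, 8]  # sums to 16
B, S = 4, 48
cos_out = [torch.randn(B, S, d) for d in axes_dim]  # stand-ins for per-axis cos
freqs_cos = torch.cat(cos_out, dim=-1)
print(freqs_cos.shape)  # torch.Size([4, 48, 16])
```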
```diff
@@ -685,18 +698,17 @@ class FluxTransformer2DModel(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         if txt_ids.ndim == 3:
-            logger.warning(
-                "Passing `txt_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
+            # logger.warning(
+            #     "Passing `txt_ids` 3d torch.Tensor is deprecated."
+            #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            # )
             txt_ids = txt_ids[0]
         if img_ids.ndim == 3:
-            logger.warning(
-                "Passing `img_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
+            # logger.warning(
+            #     "Passing `img_ids` 3d torch.Tensor is deprecated."
+            #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            # )
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)
```
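Note that this hunk only silences the deprecation warnings: as shown, 3-D ids are still reduced to their first batch entry before concatenation. A sketch of that id handling, assuming the `[0]`-indexing is kept exactly as in the diff:

```python
import torch

txt_ids = torch.randn(4, 16, 3)  # [B, S_txt, 3], now accepted silently
img_ids = torch.randn(4, 32, 3)  # [B, S_img, 3]

txt_ids, img_ids = txt_ids[0], img_ids[0]   # -> [S_txt, 3], [S_img, 3]
ids = torch.cat((txt_ids, img_ids), dim=0)  # -> [S_txt + S_img, 3]
print(ids.shape)  # torch.Size([48, 3])
```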