mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-23 21:04:56 +08:00
Compare commits
4 Commits
remove-unn
...
flux-rope-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34e607c757 | ||
|
|
6cc6c130cf | ||
|
|
666a3d9448 | ||
|
|
b89b5d1338 |
14
check_rope_batched.py
Normal file
14
check_rope_batched.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from diffusers.models.embeddings import FluxPosEmbed
|
||||||
|
import torch
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
seq_length = 16
|
||||||
|
img_seq_length = 32
|
||||||
|
txt_ids = torch.randn(batch_size, seq_length, 3)
|
||||||
|
img_ids = torch.randn(batch_size, img_seq_length, 3)
|
||||||
|
|
||||||
|
pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 4, 8])
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = pos_embed(ids)
|
||||||
|
# image_rotary_emb[0].shape=torch.Size([4, 48, 16]), image_rotary_emb[1].shape=torch.Size([4, 48, 16])
|
||||||
|
print(f"{image_rotary_emb[0].shape=}, {image_rotary_emb[1].shape=}")
|
||||||
@@ -1147,32 +1147,38 @@ def get_1d_rotary_pos_embed(
|
|||||||
"""
|
"""
|
||||||
assert dim % 2 == 0
|
assert dim % 2 == 0
|
||||||
|
|
||||||
if isinstance(pos, int):
|
# Handle both batched [B, S] and un-batched [S] inputs
|
||||||
pos = torch.arange(pos)
|
if pos.ndim == 1:
|
||||||
if isinstance(pos, np.ndarray):
|
pos = pos.unsqueeze(0) # Add a batch dimension if missing
|
||||||
pos = torch.from_numpy(pos) # type: ignore # [S]
|
|
||||||
|
|
||||||
theta = theta * ntk_factor
|
theta = theta * ntk_factor
|
||||||
freqs = (
|
freqs = (
|
||||||
1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
|
1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
|
||||||
) # [D/2]
|
) # Shape: [D/2]
|
||||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
|
||||||
|
# Replace torch.outer with broadcasted multiplication
|
||||||
|
# Old: freqs = torch.outer(pos, freqs) # Shape: [S, D/2]
|
||||||
|
# New: pos is [B, S], freqs is [D/2]. Unsqueeze pos to [B, S, 1] for broadcasting.
|
||||||
|
freqs = pos.unsqueeze(-1) * freqs # Shape: [B, S, D/2]
|
||||||
|
|
||||||
is_npu = freqs.device.type == "npu"
|
is_npu = freqs.device.type == "npu"
|
||||||
if is_npu:
|
if is_npu:
|
||||||
freqs = freqs.float()
|
freqs = freqs.float()
|
||||||
|
|
||||||
if use_real and repeat_interleave_real:
|
if use_real and repeat_interleave_real:
|
||||||
# flux, hunyuan-dit, cogvideox
|
# flux, hunyuan-dit, cogvideox
|
||||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
# Use dim=-1 for robust interleaving on the feature dimension
|
||||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
freqs_cos = freqs.cos().repeat_interleave(2, dim=-1).float() # Shape: [B, S, D]
|
||||||
|
freqs_sin = freqs.sin().repeat_interleave(2, dim=-1).float() # Shape: [B, S, D]
|
||||||
return freqs_cos, freqs_sin
|
return freqs_cos, freqs_sin
|
||||||
elif use_real:
|
elif use_real:
|
||||||
# stable audio, allegro
|
# stable audio, allegro
|
||||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # Shape: [B, S, D]
|
||||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # Shape: [B, S, D]
|
||||||
return freqs_cos, freqs_sin
|
return freqs_cos, freqs_sin
|
||||||
else:
|
else:
|
||||||
# lumina
|
# lumina
|
||||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Shape: [B, S, D/2]
|
||||||
return freqs_cis
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -489,6 +489,11 @@ class FluxPosEmbed(nn.Module):
|
|||||||
self.axes_dim = axes_dim
|
self.axes_dim = axes_dim
|
||||||
|
|
||||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
was_unbatched = ids.ndim == 2
|
||||||
|
if was_unbatched:
|
||||||
|
# Add a batch dimension to standardize processing
|
||||||
|
ids = ids.unsqueeze(0)
|
||||||
|
# ids is now expected to be [B, S, n_axes]
|
||||||
n_axes = ids.shape[-1]
|
n_axes = ids.shape[-1]
|
||||||
cos_out = []
|
cos_out = []
|
||||||
sin_out = []
|
sin_out = []
|
||||||
@@ -496,10 +501,11 @@ class FluxPosEmbed(nn.Module):
|
|||||||
is_mps = ids.device.type == "mps"
|
is_mps = ids.device.type == "mps"
|
||||||
is_npu = ids.device.type == "npu"
|
is_npu = ids.device.type == "npu"
|
||||||
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
||||||
|
|
||||||
for i in range(n_axes):
|
for i in range(n_axes):
|
||||||
cos, sin = get_1d_rotary_pos_embed(
|
cos, sin = get_1d_rotary_pos_embed(
|
||||||
self.axes_dim[i],
|
self.axes_dim[i],
|
||||||
pos[:, i],
|
pos[:, :, i], # Correct slicing for batched input
|
||||||
theta=self.theta,
|
theta=self.theta,
|
||||||
repeat_interleave_real=True,
|
repeat_interleave_real=True,
|
||||||
use_real=True,
|
use_real=True,
|
||||||
@@ -507,8 +513,15 @@ class FluxPosEmbed(nn.Module):
|
|||||||
)
|
)
|
||||||
cos_out.append(cos)
|
cos_out.append(cos)
|
||||||
sin_out.append(sin)
|
sin_out.append(sin)
|
||||||
|
|
||||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||||
|
|
||||||
|
# Squeeze the batch dim if the original input was unbatched
|
||||||
|
if was_unbatched:
|
||||||
|
freqs_cos = freqs_cos.squeeze(0)
|
||||||
|
freqs_sin = freqs_sin.squeeze(0)
|
||||||
|
|
||||||
return freqs_cos, freqs_sin
|
return freqs_cos, freqs_sin
|
||||||
|
|
||||||
|
|
||||||
@@ -685,18 +698,17 @@ class FluxTransformer2DModel(
|
|||||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||||
|
|
||||||
if txt_ids.ndim == 3:
|
if txt_ids.ndim == 3:
|
||||||
logger.warning(
|
# logger.warning(
|
||||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
# "Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||||
)
|
# )
|
||||||
txt_ids = txt_ids[0]
|
txt_ids = txt_ids[0]
|
||||||
if img_ids.ndim == 3:
|
if img_ids.ndim == 3:
|
||||||
logger.warning(
|
# logger.warning(
|
||||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
# "Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||||
)
|
# )
|
||||||
img_ids = img_ids[0]
|
img_ids = img_ids[0]
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||||
image_rotary_emb = self.pos_embed(ids)
|
image_rotary_emb = self.pos_embed(ids)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user