Compare commits

...

4 Commits

Author  SHA1        Message                                            Date
Aryan   462d2f4d00  Merge branch 'main' into refactor-rope-and-sincos  2024-10-16 17:20:05 +05:30
Aryan   4d60d144cf  Merge branch 'main' into refactor-rope-and-sincos  2024-10-16 02:13:12 +05:30
Aryan   9613541142  Merge branch 'main' into refactor-rope-and-sincos  2024-10-12 02:12:22 +05:30
Aryan   fdc6fd7bd6  update                                             2024-10-11 22:21:42 +02:00


@@ -78,6 +78,53 @@ def get_timestep_embedding(
return emb
def aryan_get_3d_sincos_pos_embed(
embed_dim: int,
spatial_size: Union[int, Tuple[int, int]],
temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
device: Optional[torch.device] = None,
) -> torch.Tensor:
r"""
Args:
embed_dim (`int`):
spatial_size (`int` or `Tuple[int, int]`):
temporal_size (`int`):
spatial_interpolation_scale (`float`, defaults to 1.0):
temporal_interpolation_scale (`float`, defaults to 1.0):
"""
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
if isinstance(spatial_size, int):
spatial_size = (spatial_size, spatial_size)
embed_dim_spatial = 3 * embed_dim // 4
embed_dim_temporal = embed_dim // 4
# 1. Spatial
grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here, w goes first
grid = torch.stack(grid).reshape(2, 1, spatial_size[1], spatial_size[0])
pos_embed_spatial = aryan_get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
# 2. Temporal
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
pos_embed_temporal = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
# 3. Concat
# [T, H*W, D // 4 * 3]
pos_embed_spatial = pos_embed_spatial.unsqueeze(0).repeat(temporal_size, 1, 1)
# [T, H*W, D // 4]
pos_embed_temporal = pos_embed_temporal.unsqueeze(1).repeat(1, spatial_size[0] * spatial_size[1], 1)
# [T, H*W, D]
pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=2)
return pos_embed
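
A minimal shape check for the torch port (an illustrative sketch; it assumes `torch` is imported at module level and uses hypothetical sizes):

    pos = aryan_get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(8, 8), temporal_size=4)
    assert pos.shape == (4, 8 * 8, 64)  # [T, H*W, D]
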
def get_3d_sincos_pos_embed(
embed_dim: int,
spatial_size: Union[int, Tuple[int, int]],
@@ -125,6 +172,29 @@ def get_3d_sincos_pos_embed(
return pos_embed
def aryan_get_2d_sincos_pos_embed(
    embed_dim: int,
    grid_size: Union[int, Tuple[int, int]],
    cls_token: bool = False,
    extra_tokens: int = 0,
    interpolation_scale: float = 1.0,
    base_size: int = 16,
    device: Optional[torch.device] = None,
):
"""
    grid_size: int of the grid height and width
    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
grid_h = torch.arange(grid_size[0], device=device, dtype=torch.float32) / (grid_size[0] / base_size) / interpolation_scale
grid_w = torch.arange(grid_size[1], device=device, dtype=torch.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here, w goes first
grid = torch.stack(grid).reshape(2, 1, grid_size[1], grid_size[0])
pos_embed = aryan_get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = torch.cat([pos_embed.new_zeros(extra_tokens, embed_dim), pos_embed])
return pos_embed
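
A sketch of the cls-token path, assuming a 16x16 grid and the default base_size of 16 (hypothetical values):

    pos = aryan_get_2d_sincos_pos_embed(embed_dim=96, grid_size=16, cls_token=True, extra_tokens=1)
    assert pos.shape == (1 + 16 * 16, 96)  # one zero row prepended for the cls token
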
def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
@@ -147,6 +217,19 @@ def get_2d_sincos_pos_embed(
return pos_embed
def aryan_get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
emb_h = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
# use half of dimensions to encode grid_w
emb_w = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D)
return emb
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
@@ -159,6 +242,27 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
return emb
def aryan_get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
"""
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
    pos = pos.reshape(-1)  # flatten to (M,)
    out = pos[:, None] * omega[None, :]  # (M, D/2), outer product
emb_sin = torch.sin(out)
emb_cos = torch.cos(out)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb
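
A sketch of the resulting layout, assuming `torch` is imported as elsewhere in this module: the first D/2 columns hold sines and the last D/2 hold cosines, so position 0 maps to zeros followed by ones.

    pos = torch.arange(3, dtype=torch.float32)
    emb = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim=8, pos=pos)
    assert emb.shape == (3, 8)
    assert torch.allclose(emb[0, :4], torch.zeros(4, dtype=emb.dtype))  # sin(0) == 0
    assert torch.allclose(emb[0, 4:], torch.ones(4, dtype=emb.dtype))   # cos(0) == 1
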
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
@@ -180,6 +284,370 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
return emb
def aryan_get_3d_rotary_pos_embed(
    embed_dim: int,
    crops_coords: Tuple[int, int],
    grid_size: Tuple[int, int],
    temporal_size: int,
    theta: int = 10000,
    use_real: bool = True,
    device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int, int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int, int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: cos and sin frequency tensors, each with shape
        `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
"""
if use_real is not True:
raise ValueError("`use_real = False` is not currently supported for get_3d_rotary_pos_embed")
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
    grid_h = torch.linspace(
        start[0], start[0] + (stop[0] - start[0]) * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
    )
    grid_w = torch.linspace(
        start[1], start[1] + (stop[1] - start[1]) * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
    )
    grid_t = torch.linspace(
        0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
    )
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
# Spatial frequencies for height and width
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
freqs_t = freqs_t[:, None, None, :].expand(
-1, grid_size_h, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h = freqs_h[None, :, None, :].expand(
temporal_size, -1, grid_size_w, -1
        ) # temporal_size, grid_size_h, grid_size_w, dim_h
freqs_w = freqs_w[None, None, :, :].expand(
temporal_size, grid_size_h, -1, -1
        ) # temporal_size, grid_size_h, grid_size_w, dim_w
freqs = torch.cat(
[freqs_t, freqs_h, freqs_w], dim=-1
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
freqs = freqs.view(
temporal_size * grid_size_h * grid_size_w, -1
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return freqs
    t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
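
A sketch of the axis split with a hypothetical head dimension of 64: dim_t + dim_h + dim_w = 16 + 24 + 24 covers the full head dimension, and each returned tensor has one row per video token.

    cos, sin = aryan_get_3d_rotary_pos_embed(
        embed_dim=64, crops_coords=((0, 0), (32, 48)), grid_size=(4, 6), temporal_size=2
    )
    assert cos.shape == sin.shape == (2 * 4 * 6, 64)  # [T*H*W, dim_t + dim_h + dim_w]
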
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: cos and sin frequency tensors, each with shape
        `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
"""
if use_real is not True:
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
# Spatial frequencies for height and width
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
freqs_t = freqs_t[:, None, None, :].expand(
-1, grid_size_h, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h = freqs_h[None, :, None, :].expand(
temporal_size, -1, grid_size_w, -1
        ) # temporal_size, grid_size_h, grid_size_w, dim_h
freqs_w = freqs_w[None, None, :, :].expand(
temporal_size, grid_size_h, -1, -1
        ) # temporal_size, grid_size_h, grid_size_w, dim_w
freqs = torch.cat(
[freqs_t, freqs_h, freqs_w], dim=-1
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
freqs = freqs.view(
temporal_size * grid_size_h * grid_size_w, -1
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return freqs
    t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
def aryan_get_2d_rotary_pos_embed(
    embed_dim: int, crops_coords: Tuple[int, int], grid_size: Tuple[int, int], use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for image tokens with 2d structure.
Args:
embed_dim: (`int`):
The embedding dimension size
        crops_coords (`Tuple[int, int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the positional embedding.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
        `torch.Tensor`: positional embedding with shape `(grid_size[0] * grid_size[1], embed_dim/2)`.
"""
start, stop = crops_coords
    grid_h = torch.linspace(
        start[0], start[0] + (stop[0] - start[0]) * (grid_size[0] - 1) / grid_size[0], grid_size[0], dtype=torch.float32
    )
    grid_w = torch.linspace(
        start[1], start[1] + (stop[1] - start[1]) * (grid_size[1] - 1) / grid_size[1], grid_size[1], dtype=torch.float32
    )
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here, w goes first
grid = torch.stack(grid)
grid = grid.reshape(2, 1, *grid.shape[1:])
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
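
A usage sketch for the real-valued path, with a hypothetical crop and grid:

    cos, sin = aryan_get_2d_rotary_pos_embed(embed_dim=64, crops_coords=((0, 0), (64, 64)), grid_size=(8, 8))
    assert cos.shape == sin.shape == (8 * 8, 64)
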
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
Args:
embed_dim: (`int`):
The embedding dimension size
        crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the positional embedding.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def aryan_get_2d_rotary_pos_embed_from_grid(
    embed_dim: int, grid: torch.Tensor, use_real: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_2d_rotary_pos_embed_lumina(
    embed_dim: int, len_h: int, len_w: int, linear_factor: float = 1.0, ntk_factor: float = 1.0
) -> torch.Tensor:
assert embed_dim % 4 == 0
emb_h = get_1d_rotary_pos_embed(
embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
) # (H, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
) # (W, D/4)
emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
return emb
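
Unlike the functions above, the Lumina variant keeps the complex representation rather than splitting into a cos/sin pair; a sketch with hypothetical sizes:

    emb = get_2d_rotary_pos_embed_lumina(embed_dim=64, len_h=8, len_w=8)
    assert emb.shape == (8, 8, 32) and emb.dtype == torch.complex64  # (H, W, D/2), complex
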
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.Tensor, np.ndarray, int],
theta: float = 10000.0,
use_real: bool = False,
linear_factor: float = 1.0,
ntk_factor: float = 1.0,
repeat_interleave_real: bool = True,
freqs_dtype: torch.dtype = torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
    position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values
    in complex64 data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`torch.Tensor` or `np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
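
The three return layouts, sketched for dim=8 over 4 positions (illustrative values only):

    cos, sin = get_1d_rotary_pos_embed(8, 4, use_real=True)  # interleaved halves, [4, 8] each
    cos2, sin2 = get_1d_rotary_pos_embed(8, 4, use_real=True, repeat_interleave_real=False)  # concatenated halves
    cis = get_1d_rotary_pos_embed(8, 4)  # complex64 frequencies, [4, 4]
    assert cos.shape == cos2.shape == (4, 8) and cis.shape == (4, 4)
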
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D]
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
    Returns:
        `torch.Tensor`: The query or key tensor with rotary embeddings applied.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
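
Tying the two halves together, a sketch of rotating a query tensor with the interleaved (use_real_unbind_dim=-1) layout, using hypothetical sizes:

    q = torch.randn(1, 2, 4, 8)  # [B, H, S, D]
    cos, sin = get_1d_rotary_pos_embed(8, 4, use_real=True)  # S=4, D=8
    q_rot = apply_rotary_emb(q, (cos, sin))
    assert q_rot.shape == q.shape and q_rot.dtype == q.dtype
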
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding with support for SD3 cropping."""
@@ -496,253 +964,6 @@ class CogView3PlusPatchEmbed(nn.Module):
return (hidden_states + pos_embed).to(hidden_states.dtype)
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: cos and sin frequency tensors, each with shape
        `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
"""
if use_real is not True:
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
# Spatial frequencies for height and width
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
freqs_t = freqs_t[:, None, None, :].expand(
-1, grid_size_h, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h = freqs_h[None, :, None, :].expand(
temporal_size, -1, grid_size_w, -1
        ) # temporal_size, grid_size_h, grid_size_w, dim_h
freqs_w = freqs_w[None, None, :, :].expand(
temporal_size, grid_size_h, -1, -1
        ) # temporal_size, grid_size_h, grid_size_w, dim_w
freqs = torch.cat(
[freqs_t, freqs_h, freqs_w], dim=-1
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
freqs = freqs.view(
temporal_size * grid_size_h * grid_size_w, -1
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return freqs
    t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
Args:
embed_dim: (`int`):
The embedding dimension size
        crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the positional embedding.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
assert embed_dim % 4 == 0
emb_h = get_1d_rotary_pos_embed(
embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
) # (H, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
) # (W, D/4)
emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
return emb
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
    position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values
    in complex64 data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D]
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
    Returns:
        `torch.Tensor`: The query or key tensor with rotary embeddings applied.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):