mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-11 23:14:37 +08:00

Comparing commits custom-rev...refactor-r (4 commits): 462d2f4d00, 4d60d144cf, 9613541142, fdc6fd7bd6
@@ -78,6 +78,53 @@ def get_timestep_embedding(
    return emb


def aryan_get_3d_sincos_pos_embed(
    embed_dim: int,
    spatial_size: Union[int, Tuple[int, int]],
    temporal_size: int,
    spatial_interpolation_scale: float = 1.0,
    temporal_interpolation_scale: float = 1.0,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    r"""
    Creates 3D sinusoidal positional embeddings.

    Args:
        embed_dim (`int`):
            The embedding dimension of the inputs. Must be divisible by 4.
        spatial_size (`int` or `Tuple[int, int]`):
            The spatial dimension of the positional embeddings. An integer is applied to both height and width.
        temporal_size (`int`):
            The temporal dimension of the positional embeddings (number of frames).
        spatial_interpolation_scale (`float`, defaults to 1.0):
            Scale factor for spatial grid interpolation.
        temporal_interpolation_scale (`float`, defaults to 1.0):
            Scale factor for temporal grid interpolation.
    """
    if embed_dim % 4 != 0:
        raise ValueError("`embed_dim` must be divisible by 4")
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    embed_dim_spatial = 3 * embed_dim // 4
    embed_dim_temporal = embed_dim // 4

    # 1. Spatial
    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here, w goes first
    grid = torch.stack(grid).reshape(2, 1, spatial_size[1], spatial_size[0])
    pos_embed_spatial = aryan_get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

    # 2. Temporal
    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
    pos_embed_temporal = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

    # 3. Concat
    # [T, H*W, D // 4 * 3]
    pos_embed_spatial = pos_embed_spatial.unsqueeze(0).repeat(temporal_size, 1, 1)

    # [T, H*W, D // 4]
    pos_embed_temporal = pos_embed_temporal.unsqueeze(1).repeat(1, spatial_size[0] * spatial_size[1], 1)

    # [T, H*W, D]
    pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=2)
    return pos_embed
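
# Illustrative sketch (not part of the diff): a minimal shape check for the
# torch rewrite above, assuming this module's imports. 3/4 of the channels
# encode space and 1/4 encode time, concatenated along the last dimension.
#
#     pos_embed = aryan_get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(8, 8), temporal_size=4)
#     assert pos_embed.shape == (4, 8 * 8, 64)  # [T, H*W, D]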


def get_3d_sincos_pos_embed(
    embed_dim: int,
    spatial_size: Union[int, Tuple[int, int]],
@@ -125,6 +172,29 @@ def get_3d_sincos_pos_embed(
    return pos_embed


def aryan_get_2d_sincos_pos_embed(
    embed_dim: int,
    grid_size: Union[int, Tuple[int, int]],
    cls_token: bool = False,
    extra_tokens: int = 0,
    interpolation_scale: float = 1.0,
    base_size: int = 16,
    device: Optional[torch.device] = None,
):
    """
    Args:
        grid_size (`int` or `Tuple[int, int]`): The height and width of the grid.

    Returns:
        pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = torch.arange(grid_size[0], device=device, dtype=torch.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = torch.arange(grid_size[1], device=device, dtype=torch.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here, w goes first
    grid = torch.stack(grid).reshape(2, 1, grid_size[1], grid_size[0])

    pos_embed = aryan_get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

    if cls_token and extra_tokens > 0:
        pos_embed = torch.cat([pos_embed.new_zeros(extra_tokens, embed_dim), pos_embed])

    return pos_embed
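
# Illustrative sketch (not part of the diff): since this refactor ports numpy
# helpers to torch, the new function should match the existing numpy-based
# get_2d_sincos_pos_embed on CPU, up to floating-point tolerance.
#
#     ref = get_2d_sincos_pos_embed(embed_dim=32, grid_size=8)        # np.ndarray
#     new = aryan_get_2d_sincos_pos_embed(embed_dim=32, grid_size=8)  # torch.Tensor
#     assert np.allclose(ref, new.numpy(), atol=1e-6)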


def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
@@ -147,6 +217,19 @@ def get_2d_sincos_pos_embed(
    return pos_embed


def aryan_get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of the dimensions to encode grid_h
    emb_h = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    # use the other half to encode grid_w
    emb_w = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")
@@ -159,6 +242,27 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    return emb


def aryan_get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
    """
    Args:
        embed_dim (`int`): Output dimension D for each position.
        pos (`torch.Tensor`): A tensor of positions to be encoded, of shape (M,).

    Returns:
        `torch.Tensor`: Sinusoidal positional embeddings of shape (M, D).
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = pos[:, None] * omega[None, :]  # (M, D/2), outer product

    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)
    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)

    return emb
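
# Worked example (illustrative, not part of the diff; assumes `import math`):
# with embed_dim=4 the frequencies are omega = [1, 1/100], so a position p
# encodes as [sin(p), sin(p / 100), cos(p), cos(p / 100)].
#
#     emb = aryan_get_1d_sincos_pos_embed_from_grid(4, torch.tensor([2.0]))
#     expected = torch.tensor([[math.sin(2.0), math.sin(0.02), math.cos(2.0), math.cos(0.02)]], dtype=torch.float64)
#     assert torch.allclose(emb, expected)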


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
@@ -180,6 +284,370 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    return emb


def aryan_get_3d_rotary_pos_embed(
    embed_dim: int,
    crops_coords: Tuple[Tuple[int, int], Tuple[int, int]],
    grid_size: Tuple[int, int],
    temporal_size: int,
    theta: int = 10000,
    use_real: bool = True,
    device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[Tuple[int, int], Tuple[int, int]]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int, int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`int`, defaults to 10000):
            Scaling factor for frequency computation.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: the cosine and sine components, each with shape
        `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
    """
    if use_real is not True:
        raise ValueError("`use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    start, stop = crops_coords
    grid_size_h, grid_size_w = grid_size
    # The linspace endpoints below emulate np.linspace(start, stop, n, endpoint=False)
    grid_h = torch.linspace(
        start[0], start[0] + (stop[0] - start[0]) * (grid_size_h - 1) / grid_size_h, grid_size_h,
        device=device, dtype=torch.float32,
    )
    grid_w = torch.linspace(
        start[1], start[1] + (stop[1] - start[1]) * (grid_size_w - 1) / grid_size_w, grid_size_w,
        device=device, dtype=torch.float32,
    )
    grid_t = torch.linspace(0, temporal_size - 1, temporal_size, device=device, dtype=torch.float32)

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
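
# Illustrative sketch (not part of the diff): with use_real=True the function
# returns a (cos, sin) pair covering every (t, h, w) token position, with
# dim_t + dim_h + dim_w = D/4 + 3D/8 + 3D/8 = D channels.
#
#     cos, sin = aryan_get_3d_rotary_pos_embed(
#         embed_dim=64, crops_coords=((0, 0), (32, 32)), grid_size=(16, 16), temporal_size=8
#     )
#     assert cos.shape == sin.shape == (8 * 16 * 16, 64)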


def get_3d_rotary_pos_embed(
    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError("`use_real = False` is not currently supported for get_3d_rotary_pos_embed")
    start, stop = crops_coords
    grid_size_h, grid_size_w = grid_size
    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin


def aryan_get_2d_rotary_pos_embed(
    embed_dim: int,
    crops_coords: Tuple[Tuple[int, int], Tuple[int, int]],
    grid_size: Tuple[int, int],
    use_real: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for image tokens with 2D structure.

    Args:
        embed_dim (`int`):
            The embedding dimension size.
        crops_coords (`Tuple[Tuple[int, int], Tuple[int, int]]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int, int]`):
            The grid size of the positional embedding (height, width).
        use_real (`bool`):
            If True, return the real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        `torch.Tensor`: positional embedding with shape `(grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    start, stop = crops_coords
    # The linspace endpoints below emulate np.linspace(start, stop, n, endpoint=False)
    grid_h = torch.linspace(
        start[0], start[0] + (stop[0] - start[0]) * (grid_size[0] - 1) / grid_size[0], grid_size[0], dtype=torch.float32
    )
    grid_w = torch.linspace(
        start[1], start[1] + (stop[1] - start[1]) * (grid_size[1] - 1) / grid_size[1], grid_size[1], dtype=torch.float32
    )
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here, w goes first
    grid = torch.stack(grid)
    grid = grid.reshape(2, 1, *grid.shape[1:])

    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed
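
# Illustrative sketch (not part of the diff): because the torch.linspace calls
# above emulate np.linspace(..., endpoint=False), the torch and numpy variants
# should produce the same grids and hence the same embeddings.
#
#     cos_t, sin_t = aryan_get_2d_rotary_pos_embed(32, ((0, 0), (16, 16)), (8, 8))
#     cos_n, sin_n = get_2d_rotary_pos_embed(32, ((0, 0), (16, 16)), (8, 8))
#     assert torch.allclose(cos_t, cos_n) and torch.allclose(sin_t, sin_n)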


def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
    """
    RoPE for image tokens with 2D structure.

    Args:
        embed_dim (`int`):
            The embedding dimension size.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the positional embedding.
        use_real (`bool`):
            If True, return the real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
    """
    start, stop = crops_coords
    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here, w goes first
    grid = np.stack(grid, axis=0)  # [2, W, H]

    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed


def aryan_get_2d_rotary_pos_embed_from_grid(
    embed_dim: int, grid: torch.Tensor, use_real: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    assert embed_dim % 4 == 0

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_rotary_pos_embed(
        embed_dim // 2, grid[0].reshape(-1), use_real=use_real
    )  # (H*W, D/2) if use_real else (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(
        embed_dim // 2, grid[1].reshape(-1), use_real=use_real
    )  # (H*W, D/2) if use_real else (H*W, D/4)

    if use_real:
        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
        return cos, sin
    else:
        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
        return emb


def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    assert embed_dim % 4 == 0

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_rotary_pos_embed(
        embed_dim // 2, grid[0].reshape(-1), use_real=use_real
    )  # (H*W, D/2) if use_real else (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(
        embed_dim // 2, grid[1].reshape(-1), use_real=use_real
    )  # (H*W, D/2) if use_real else (H*W, D/4)

    if use_real:
        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
        return cos, sin
    else:
        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
        return emb


def get_2d_rotary_pos_embed_lumina(
    embed_dim: int, len_h: int, len_w: int, linear_factor: float = 1.0, ntk_factor: float = 1.0
) -> torch.Tensor:
    assert embed_dim % 4 == 0

    emb_h = get_1d_rotary_pos_embed(
        embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
    )  # (H, D/4)
    emb_w = get_1d_rotary_pos_embed(
        embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
    )  # (W, D/4)
    emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1)  # (H, W, D/4, 1)
    emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1)  # (H, W, D/4, 1)

    emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2)  # (H, W, D/2)
    return emb


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.Tensor, np.ndarray, int],
    theta: float = 10000.0,
    use_real: bool = False,
    linear_factor: float = 1.0,
    ntk_factor: float = 1.0,
    repeat_interleave_real: bool = True,
    freqs_dtype: torch.dtype = torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
    position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values
    in complex64 data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`torch.Tensor`, `np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar.
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation.
        use_real (`bool`, *optional*):
            If True, return the real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, the real part and imaginary part are each interleaved with themselves to reach
            `dim`. Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            The dtype of the frequency tensor.

    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis
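
# Illustrative sketch (not part of the diff): the two use_real layouts expand
# the D/2 base frequencies to D channels differently: interleaved as
# [f0, f0, f1, f1] versus concatenated as [f0, f1, f0, f1].
#
#     cos_i, _ = get_1d_rotary_pos_embed(4, 3, use_real=True)
#     cos_c, _ = get_1d_rotary_pos_embed(4, 3, use_real=True, repeat_interleave_real=False)
#     assert cos_i.shape == cos_c.shape == (3, 4)
#     assert torch.equal(cos_i[:, 1], cos_c[:, 0]) and torch.equal(cos_i[:, 2], cos_c[:, 1])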


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensor using the provided frequency tensor 'freqs_cis'. The input tensor is reshaped
    as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensor
    contains rotary embeddings and is returned as a real tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D]
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        `torch.Tensor`: The query or key tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
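
# Illustrative sketch (not part of the diff): rotating query/key heads with the
# interleaved (cos, sin) pair produced by get_1d_rotary_pos_embed.
#
#     batch, heads, seq, head_dim = 2, 4, 16, 32
#     q = torch.randn(batch, heads, seq, head_dim)
#     k = torch.randn(batch, heads, seq, head_dim)
#     freqs_cis = get_1d_rotary_pos_embed(head_dim, seq, use_real=True)  # ([S, D], [S, D])
#     q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis)
#     assert q.shape == k.shape == (batch, heads, seq, head_dim)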


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding with support for SD3 cropping."""

@@ -496,253 +964,6 @@ class CogView3PlusPatchEmbed(nn.Module):
        return (hidden_states + pos_embed).to(hidden_states.dtype)


class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):