Compare commits

...

1 Commits

Author SHA1 Message Date
DN6
818a3b228f update 2025-10-07 17:57:34 +05:30

View File

@@ -341,7 +341,12 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
if pos.device.type == "mps":
dtype = torch.float32
else:
dtype = torch.float64
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)