Compare commits

...

2 Commits

Author SHA1 Message Date
sayakpaul
805d5133f7 start supporting sage blackwell. 2026-03-06 10:12:03 +05:30
Shenghai Yuan
46bd005730 Convert tensors to float in Helios's optimized_scale function (#13214)
Convert tensors to float in optimized_scale function
2026-03-05 17:08:44 -10:00
2 changed files with 11 additions and 2 deletions

View File

@@ -703,9 +703,16 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version)
is_blackwell = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10 # sm_100+
version = 2 if is_blackwell and backend == AttentionBackendName.SAGE_HUB else config.version
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=version)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
function_attr = (
"sageattn3_blackwell"
if is_blackwell and backend == AttentionBackendName.SAGE_HUB
else config.function_attr
)
config.kernel_fn = _resolve_kernel_attr(kernel_module, function_attr)
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)

View File

@@ -76,6 +76,8 @@ EXAMPLE_DOC_STRING = """
def optimized_scale(positive_flat, negative_flat):
positive_flat = positive_flat.float()
negative_flat = negative_flat.float()
# Calculate dot product
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of the unconditional term