Compare commits

...

1 Commits

Author SHA1 Message Date
sayakpaul
805d5133f7 start supporting sage blackwell. 2026-03-06 10:12:03 +05:30

View File

@@ -703,9 +703,16 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version)
is_blackwell = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10 # sm_100+
version = 2 if is_blackwell and backend == AttentionBackendName.SAGE_HUB else config.version
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=version)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
function_attr = (
"sageattn3_blackwell"
if is_blackwell and backend == AttentionBackendName.SAGE_HUB
else config.function_attr
)
config.kernel_fn = _resolve_kernel_attr(kernel_module, function_attr)
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)