# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
#
# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Literal

import torch
import triton
import triton.language as tl
from torch.library import wrap_triton

@triton.jit
def triton_scale_swizzle(
    scale_ptr: torch.Tensor,
    scale_rows: int,
    scale_cols: int,
    output_ptr: torch.Tensor,
    input_row_stride: int,
    output_block_stride: int,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    """
    Rearranges tensor data from row-major to block-scaled swizzle format.

    Args:
        scale_ptr: Pointer to the input scale tensor
        scale_rows: Number of rows in the scale tensor
        scale_cols: Number of columns in the scale tensor
        output_ptr: Pointer to the output tensor
        input_row_stride: Stride between rows in the input tensor
        output_block_stride: Stride between blocks in the output tensor
        BLOCK_ROWS: Number of rows in a tile (compile-time constant)
        BLOCK_COLS: Number of columns in a tile (compile-time constant)
    """
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    rows = tl.arange(0, BLOCK_ROWS)[:, None]
    cols = tl.arange(0, BLOCK_COLS)[None, :]

    # Calculate starting row and column for this tile
    start_row = pid_row * BLOCK_ROWS
    start_col = pid_col * BLOCK_COLS
    global_rows = start_row + rows
    global_cols = start_col + cols

    mask = (global_rows < scale_rows) & (global_cols < scale_cols)

    input_scales = tl.load(
        scale_ptr + global_rows * input_row_stride + global_cols,
        mask=mask,
        other=0.0,
    )

    # 1) Split each tile-local row index into a group of 32 and its offset
    r_div_32 = rows // 32
    r_mod_32 = rows % 32

    # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates
    dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols

    # Flatten
    dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
    scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))

    # Calculate block offset using provided output block stride
    LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
    block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)

    tl.store(
        output_ptr + block_offset + dest_indices_flat,
        scales_flat,
    )
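
# Worked example of the kernel's per-tile index mapping (illustrative values):
# within one 128x4 tile, the element at local (row=33, col=2) has
# r_div_32 = 1 and r_mod_32 = 1, so it is stored at flat offset
# 1 * 16 + 1 * 4 + 2 = 22 inside that tile's 512-element output block,
# i.e. the (32, 16) per-block layout described in the cuBLAS block-scaling
# documentation linked below.
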
def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
    """
    Rearranges an E8M0 tensor scale from row-major format to
    block-scaled swizzle format.

    This format is suitable for Tmem as described in NVIDIA documentation:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        scale_tensor: Input tensor in row-major format with 8-bit elements

    Returns:
        Rearranged tensor in block-scaled swizzle format
    """
    assert scale_tensor.element_size() == 1, (
        "Expected element size to be 1 byte (8 bits)"
    )
    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

    rows, cols = scale_tensor.shape

    # Calculate blocks needed
    n_row_blocks = triton.cdiv(rows, 128)
    n_col_blocks = triton.cdiv(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    out = scale_tensor.new_empty((padded_rows, padded_cols))

    # Input stride (for row-major format)
    input_row_stride = cols

    # We probably want to handle multiple blocks per tile, but
    # for now keep it simple
    BLOCK_ROWS, BLOCK_COLS = 128, 4

    # Output block stride for the rearranged format
    output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)

    grid = lambda META: (
        triton.cdiv(padded_rows, BLOCK_ROWS),
        triton.cdiv(padded_cols, BLOCK_COLS),
    )

    wrap_triton(triton_scale_swizzle)[grid](
        scale_tensor.view(torch.uint8),
        rows,
        cols,
        out.view(torch.uint8),
        input_row_stride,
        output_block_stride,
        BLOCK_ROWS=BLOCK_ROWS,
        BLOCK_COLS=BLOCK_COLS,
    )

    return out
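
# Usage sketch for triton_mx_block_rearrange (a minimal example, assuming a
# CUDA device with Triton available; shapes are illustrative):
#
#     scales = torch.randint(0, 256, (256, 8), dtype=torch.uint8, device="cuda")
#     swizzled = triton_mx_block_rearrange(scales)  # shape (256, 8)
#
# Because 256 and 8 are already multiples of 128 and 4, no padding is added;
# the output keeps the input shape but reorders elements into 128x4 tiles,
# each laid out as a 512-element (32, 16) block.
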
def ceil_div(a, b):
    return (a + b - 1) // b

def to_blocked(
    input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton"
) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying
    the rearrangement pattern.

    See:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)
        backend: "torch" (PyTorch path) or "triton" (Triton kernel)

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    if backend == "triton":
        return triton_mx_block_rearrange(input_matrix).flatten()
    elif backend != "torch":
        raise ValueError(f'backend must be "torch" or "triton", got {backend!r}')

    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    # The torch path assumes the input is already padded to multiples of (128, 4)
    padded = input_matrix
    assert (rows, cols) == (padded_rows, padded_cols)

    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
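

if __name__ == "__main__":
    # Minimal sanity-check sketch (illustrative only, shapes are assumptions):
    # compare the Triton and PyTorch backends of to_blocked on an
    # already-padded uint8 scale tensor. Requires a CUDA device.
    if torch.cuda.is_available():
        scales = torch.randint(0, 256, (256, 8), dtype=torch.uint8, device="cuda")
        ref = to_blocked(scales, backend="torch")
        out = to_blocked(scales, backend="triton")
        torch.testing.assert_close(out, ref, atol=0, rtol=0)
        print("triton and torch backends agree, output numel:", out.numel())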