# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
#
# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

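# Utilities for rearranging E8M0 scale tensors from row-major order into the
# block-scaled swizzle layout used for block scaling factors (see the cuBLAS
# link in the docstrings below).
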
from typing import Literal

import torch
import triton
import triton.language as tl
from torch.library import wrap_triton


@triton.jit
def triton_scale_swizzle(
    scale_ptr: torch.Tensor,
    scale_rows: int,
    scale_cols: int,
    output_ptr: torch.Tensor,
    input_row_stride: int,
    output_block_stride: int,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    """
    Rearranges tensor data from row-major to block-scaled swizzle format.

    Args:
        scale_ptr: Pointer to the input scale tensor
        scale_rows: Number of rows in the scale tensor
        scale_cols: Number of columns in the scale tensor
        output_ptr: Pointer to the output tensor
        input_row_stride: Stride between rows in the input tensor
        output_block_stride: Stride between blocks in the output tensor
        BLOCK_ROWS: Number of rows in a tile (compile-time constant)
        BLOCK_COLS: Number of columns in a tile (compile-time constant)
    """
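    # Each program instance swizzles one (BLOCK_ROWS, BLOCK_COLS) tile of scales.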
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    rows = tl.arange(0, BLOCK_ROWS)[:, None]
    cols = tl.arange(0, BLOCK_COLS)[None, :]

    # Calculate starting row and column for this tile
    start_row = pid_row * BLOCK_ROWS
    start_col = pid_col * BLOCK_COLS
    global_rows = start_row + rows
    global_cols = start_col + cols

    mask = (global_rows < scale_rows) & (global_cols < scale_cols)

    input_scales = tl.load(
        scale_ptr + global_rows * input_row_stride + global_cols,
        mask=mask,
        other=0.0,
    )

    r_div_32 = rows // 32
    r_mod_32 = rows % 32

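    # Within a 128x4 tile, the scale at (r, c) lands at flat offset
    # (r % 32) * 16 + (r // 32) * 4 + c: the tile is stored as a 32x16 block
    # whose rows each pack four groups of four columns, one group per
    # 32-row chunk of the tile.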
    # Rearrange to (32, 4, 4) then to final (32, 16) coordinates
    dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols

    # Flatten
    dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
    scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))

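    # Each (BLOCK_ROWS, BLOCK_COLS) tile is written as one contiguous run of
    # BLOCK_ROWS * BLOCK_COLS elements; with the stride passed by
    # triton_mx_block_rearrange, tiles end up row-major over (pid_row, pid_col).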
    # Calculate block offset using provided output block stride
    LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
    block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)

    tl.store(
        output_ptr + block_offset + dest_indices_flat,
        scales_flat,
    )


def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
    """
    Rearranges an E8M0 tensor scale from row-major format to
    block-scaled swizzle format.

    This format is suitable for Tmem as described in NVIDIA documentation:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        scale_tensor: Input tensor in row-major format with 8-bit elements

    Returns:
        Rearranged tensor in block-scaled swizzle format
    """
    assert scale_tensor.element_size() == 1, (
        "Expected element size to be 1 byte (8 bits)"
    )
    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

    rows, cols = scale_tensor.shape

    # Calculate blocks needed
    n_row_blocks = triton.cdiv(rows, 128)
    n_col_blocks = triton.cdiv(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

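    # The output is allocated at the padded size; out-of-range input positions
    # are filled with zeros by the kernel's masked load.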
    out = scale_tensor.new_empty((padded_rows, padded_cols))

    # Input stride (for row-major format)
    input_row_stride = cols

    # We probably want to handle multiple blocks per tile, but
    # for now keep it simple
    BLOCK_ROWS, BLOCK_COLS = 128, 4

    # Output block stride for the rearranged format
    output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)

    grid = lambda META: (
        triton.cdiv(padded_rows, BLOCK_ROWS),
        triton.cdiv(padded_cols, BLOCK_COLS),
    )

    wrap_triton(triton_scale_swizzle)[grid](
        scale_tensor.view(torch.uint8),
        rows,
        cols,
        out.view(torch.uint8),
        input_row_stride,
        output_block_stride,
        BLOCK_ROWS=BLOCK_ROWS,
        BLOCK_COLS=BLOCK_COLS,
    )

    return out


def ceil_div(a, b):
    return (a + b - 1) // b


def to_blocked(
    input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton"
) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying
    the rearrangement pattern.

    See:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)
        backend: "torch" (PyTorch path) or "triton" (Triton kernel)

    Returns:
        Flattened rearranged tensor with
        32 * ceil_div(H, 128) * 16 * ceil_div(W, 4) elements
    """
    if backend == "triton":
        return triton_mx_block_rearrange(input_matrix).flatten()
    elif backend != "torch":
        raise ValueError(f'backend must be "torch" or "triton", got {backend!r}')

    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    assert (rows, cols) == (padded_rows, padded_cols), (
        "torch backend expects dimensions already padded to multiples of (128, 4)"
    )

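    # Shape flow: (H, W) -> view (n_row_blocks, 128, n_col_blocks, 4)
    #   -> permute to (n_row_blocks, n_col_blocks, 128, 4)
    #   -> split the 128 rows into (4, 32) and swap to (tiles, 32, 4, 4)
    #   -> merge the last two dims into (tiles, 32, 16), matching the kernel's
    #      dest index r_mod_32 * 16 + r_div_32 * 4 + col.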
    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
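

# Minimal usage sketch, assuming a 256x8 E8M0 scale tensor stored as uint8
# (values are arbitrary). The torch and triton backends should produce the
# same flattened blocked layout; the triton path needs a CUDA device.
if __name__ == "__main__":
    scales = torch.randint(0, 256, (256, 8), dtype=torch.uint8)
    ref = to_blocked(scales, backend="torch")
    # 2 row blocks x 2 col blocks, each 128x4 tile stored as 32x16 = 512 elements
    assert ref.numel() == 2048
    if torch.cuda.is_available():
        tri = to_blocked(scales.cuda(), backend="triton")
        assert torch.equal(ref, tri.cpu())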