[mamf-finder] add mxfp8 support #111

Open · wants to merge 1 commit into base: master
compute/accelerator/benchmarks/mamf-finder.py (64 changes: 59 additions & 5 deletions)
@@ -156,6 +156,41 @@ def flush(self):
        if self.verbose:
            self.stdout.flush()

+# from https://github.com/pytorch/pytorch/blob/b432443cf2fdcd2575c6e3363d4f86448a5d6650/test/test_matmul_cuda.py#L924
+# XXX: hopefully pytorch will have a core function for that
+def ceil_div(a, b): return (a + b - 1) // b
+def to_blocked(input_matrix) -> torch.Tensor:
+    """
+    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+
+    See:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        input_matrix: Input tensor of shape (H, W)
+
+    Returns:
+        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
+    """
+    rows, cols = input_matrix.shape
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = input_matrix
+    # Ideally we would use torch.nn.pad but it doesn't support float8_e8m0fnu for now
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
+        padded[:rows, :cols] = input_matrix
+
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+
+    return rearranged.flatten()

def print_benchmark_header(dtype, device, notes="None"):

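A quick sanity check of the to_blocked helper above, as a sketch for illustration rather than part of the patch: it builds a per-block e8m0 scale tensor the same way the new mxfp8 path does and verifies the size of the flattened blocked layout, 32*ceil_div(H,128) rows by 16*ceil_div(W,4) cols. It assumes a CUDA device, a PyTorch build that ships torch.float8_e8m0fnu, and the ceil_div/to_blocked definitions from this diff; the shapes are made up.

import torch

m, k, BLOCK_SIZE = 1024, 4096, 32  # illustrative shapes only

# one e8m0 scale factor per 32-element block along k, as in the mxfp8 path below
scale = torch.full((m, ceil_div(k, BLOCK_SIZE)), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)

blocked = to_blocked(scale)

# the flattened result holds 32*ceil_div(H,128) x 16*ceil_div(W,4) scale factors
H, W = scale.shape
assert blocked.numel() == 32 * ceil_div(H, 128) * 16 * ceil_div(W, 4)
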
@@ -228,7 +263,7 @@ def func_wrapper(*args, **kwargs):

    # fp8 requires special handling depending on the vendor:
    # float8_e4m3fn for nvidia, float8_e4m3fnuz for amd
-    fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+    fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e8m0fnu]
    if dtype in fp8_dtypes:
        # torch._scaled_mm is different before pt-2.5
        if version.parse(torch.__version__) < version.parse("2.5"):
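
The vendor comment above maps to a runtime choice of fp8 format; a minimal sketch of one way to make that choice, not code from this PR, assuming torch.version.hip is an acceptable way to detect a ROCm build:

import torch

# ROCm builds set torch.version.hip to a version string; CUDA builds leave it as None
if torch.version.hip is not None:
    fp8_dtype = torch.float8_e4m3fnuz  # AMD
else:
    fp8_dtype = torch.float8_e4m3fn    # NVIDIA
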
@@ -238,14 +273,33 @@ def func_wrapper(*args, **kwargs):

            A = torch.randn(m, k, dtype=torch.float32, device=device).contiguous()
            B = torch.randn(n, k, dtype=torch.float32, device=device).contiguous().t()
-            scale = torch.tensor([1.0]).to(device)
-            A = A.to(dtype)
-            B = B.to(dtype)

+            if dtype == torch.float8_e8m0fnu:
+                # mxfp8
+                BLOCK_SIZE = 32
+                # from https://github.com/pytorch/pytorch/blob/b432443cf2fdcd2575c6e3363d4f86448a5d6650/test/test_matmul_cuda.py#L950-L961
+
+                A = A.to(torch.float8_e4m3fn)
+                B = B.to(torch.float8_e4m3fn)
+
+                scale_a = torch.full((m, ceil_div(k, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
+                scale_b = torch.full((n, ceil_div(k, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
+                # convert to swizzled format
+                scale_a = to_blocked(scale_a)
+                scale_b = to_blocked(scale_b)
+
+                out_dtype = torch.bfloat16
+            else:
+                scale_a = torch.tensor([1.0]).to(device)
+                scale_b = scale_a
+                A = A.to(dtype)
+                B = B.to(dtype)
+                out_dtype = dtype

            # Simplified call for PyTorch 2.5+
            @time_it(total_iterations)
            def time_iterations():
-                C = torch._scaled_mm(A, B, scale, scale)
+                C = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=out_dtype)

    else:
        A = torch.randn(m, k, dtype=dtype, device=device).contiguous()
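
For reference, a minimal end-to-end sketch of the mxfp8 matmul this PR wires up, put together from the patch above. It assumes a GPU and PyTorch build where torch._scaled_mm (a private API that may change) accepts block-scaled e8m0 scale factors, and it reuses the ceil_div/to_blocked helpers from the diff; the shapes are illustrative only.

import torch

device = "cuda"
m, n, k = 1024, 1024, 2048  # illustrative shapes only
BLOCK_SIZE = 32

# e4m3 data; B is transposed so torch._scaled_mm sees a column-major second operand
A = torch.randn(m, k, dtype=torch.float32, device=device).contiguous().to(torch.float8_e4m3fn)
B = torch.randn(n, k, dtype=torch.float32, device=device).contiguous().t().to(torch.float8_e4m3fn)

# one e8m0 scale per 32-element block along k, swizzled into the blocked layout
scale_a = to_blocked(torch.full((m, ceil_div(k, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu))
scale_b = to_blocked(torch.full((n, ceil_div(k, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu))

C = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16)
print(C.shape)  # torch.Size([1024, 1024])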