AI-Generated FP8 GEMMs on AMD MI355X

Writing fast FP8 GEMMs just got faster

Written by

Cătălin Milu

Published on

Apr 16, 2026

TLDR: This is part 1 of a series on low-precision GEMM kernels on MI355X. We use MakoraGenerate to write fast FP8 GEMMs in HIP. We go over what makes FP8 easy and hard, and release code along with performance results. The resulting kernels beat the publicly available AITER library provided by AMD on a variety of shapes. We'll be exploring FP4 and FP6 next, along with some clever features of the hardware we can exploit.

FP8 refresher

FP8 is a family of 8-bit floating-point formats used to increase throughput and reduce memory-bandwidth pressure relative to FP16 and BF16. The exact dynamic range and precision depend on the specific encoding. Two FP8 encodings are common in practice:

  • E4M3: 4 exponent bits, 3 mantissa bits (more mantissa precision, smaller range)

  • E5M2: 5 exponent bits, 2 mantissa bits (larger range, less precision)
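To make the range/precision trade-off concrete, here is a small host-side C++ decoder for both encodings. It is a sketch that assumes the OCP FP8 conventions (E4M3 without infinities, E5M2 IEEE-like); the function name `decode_fp8` is ours, and AMD's older `fnuz` variants, which come up later, use a different bias and different special values.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode one OCP FP8 byte to float. E4M3 (exp_bits = 4, man_bits = 3) has no
// infinities and uses only S.1111.111 as NaN; E5M2 (5, 2) is IEEE-754-like, with
// an all-ones exponent reserved for Inf/NaN.
float decode_fp8(uint8_t bits, int exp_bits, int man_bits) {
    assert((exp_bits == 4 && man_bits == 3) || (exp_bits == 5 && man_bits == 2));
    const int bias    = (1 << (exp_bits - 1)) - 1;         // 7 for E4M3, 15 for E5M2
    const int exp_max = (1 << exp_bits) - 1;
    const int man_max = (1 << man_bits) - 1;
    const int sign = bits >> 7;
    const int exp  = (bits >> man_bits) & exp_max;
    const int man  = bits & man_max;

    if (exp_bits == 5 && exp == exp_max)                    // E5M2 Inf / NaN
        return man == 0 ? (sign ? -INFINITY : INFINITY) : NAN;
    if (exp_bits == 4 && exp == exp_max && man == man_max)  // the only E4M3 NaN
        return NAN;

    const float frac = float(man) / float(1 << man_bits);
    const float mag  = (exp == 0)
        ? std::ldexp(frac, 1 - bias)                        // subnormal: 0.m * 2^(1-bias)
        : std::ldexp(1.0f + frac, exp - bias);              // normal:    1.m * 2^(exp-bias)
    return sign ? -mag : mag;
}
```

For example, `decode_fp8(0x7E, 4, 3)` gives 448, the largest finite E4M3 value, while `decode_fp8(0x7B, 5, 2)` gives 57344 for E5M2. The price of that range is visible near 1.0: the next E4M3 value above 1.0 is 1.125, but the next E5M2 value is 1.25.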

FP8 GEMM is typically paired with explicit scaling. Instead of treating FP8 as a drop-in replacement for FP16, you treat it as a compressed representation and recover a usable numeric range via scale factors.
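A minimal sketch of the simplest version of this idea, per-tensor scaling, is shown below. It assumes OCP E4M3 and picks the scale so the tensor's largest magnitude maps onto 448; the brute-force rounding and all names (`quantize_per_tensor`, `QuantizedTensor`, etc.) are illustrative, not taken from any library or from the kernels in this post.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// OCP E4M3 decode: bias 7, no infinities, S.1111.111 is NaN.
float e4m3_to_float(uint8_t b) {
    int s = b >> 7, e = (b >> 3) & 0xF, m = b & 0x7;
    if (e == 0xF && m == 0x7) return NAN;
    float mag = e == 0 ? std::ldexp(m / 8.0f, -6) : std::ldexp(1.0f + m / 8.0f, e - 7);
    return s ? -mag : mag;
}

// Nearest E4M3 code by brute force over all 256 codes. Ties go to the lower code
// index (not round-to-nearest-even); real kernels use hardware conversions instead.
uint8_t float_to_e4m3(float x) {
    uint8_t best = 0;
    float best_err = INFINITY;
    for (int c = 0; c < 256; ++c) {
        float v = e4m3_to_float(uint8_t(c));
        if (std::isnan(v)) continue;
        float err = std::fabs(v - x);
        if (err < best_err) { best_err = err; best = uint8_t(c); }
    }
    return best;
}

struct QuantizedTensor {
    std::vector<uint8_t> codes;  // FP8 payload
    float scale;                 // single per-tensor scale: x ~= decode(code) * scale
};

// Per-tensor scaling: map max |x| onto the largest finite E4M3 value (448).
QuantizedTensor quantize_per_tensor(const std::vector<float>& x) {
    assert(!x.empty());
    float amax = 0.f;
    for (float v : x) amax = std::fmax(amax, std::fabs(v));
    QuantizedTensor q;
    q.scale = amax > 0.f ? amax / 448.0f : 1.0f;
    q.codes.reserve(x.size());
    for (float v : x) q.codes.push_back(float_to_e4m3(v / q.scale));
    return q;
}

float dequantize(const QuantizedTensor& q, std::size_t i) {
    assert(i < q.codes.size());
    return e4m3_to_float(q.codes[i]) * q.scale;
}
```

With inputs `{0.001f, -3.5f, 1000.0f, 42.0f}`, the scale is 1000/448, so 1000 round-trips almost exactly while 0.001 falls below the smallest E4M3 subnormal after scaling and dequantizes to 0. One large outlier sets the scale for the whole tensor, which is what motivates the finer-grained schemes below.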

Reference: NVIDIA

Achieving high performance

Most high-performance FP8 GEMMs use one of these scaling schemes:

  • Per-tensor scaling

  • Per-row and per-column scaling

  • Block scaling (scales per M block and N block, and sometimes per K block)

Block scaling is a compromise between accuracy and overhead: it is more accurate than a single per-tensor scale, while storing and applying far fewer scale factors than per-element scaling would. MakoraGenerate must reason about these scale factors because they directly affect both numerical correctness and kernel performance. Specifically, the scaling scheme determines how scale values are loaded and cached, and whether they are applied in the epilogue or fused directly into the main compute loop.
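The epilogue-versus-main-loop distinction can be seen in a plain reference implementation. The sketch below assumes a layout that matches the indexing in the kernel excerpt later in this post (A scales per row and per K block, B scales per N block and per K block, B stored N x K); the block sizes `bn` and `bk` and the function name are illustrative parameters, and values are kept as already-decoded floats for clarity.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference block-scaled GEMM, C = A * B^T:
//   A:       M x K, row-major (quantized values, stored as float here)
//   B:       N x K, row-major (K-contiguous, like the kernel's b_smem)
//   a_scale: M x (K / bk)         one scale per row and per K block
//   b_scale: (N / bn) x (K / bk)  one scale per N block and per K block
// The unscaled partial dot product over one K block is formed first and the
// combined scale is applied once per K block, accumulating in FP32.
std::vector<float> block_scaled_gemm(const std::vector<float>& A, const std::vector<float>& B,
                                     const std::vector<float>& a_scale,
                                     const std::vector<float>& b_scale,
                                     int M, int N, int K, int bn, int bk) {
    assert(K % bk == 0 && N % bn == 0);
    const int scale_K = K / bk;
    std::vector<float> C(std::size_t(M) * N, 0.f);
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;                                  // FP32 accumulator
            for (int kb = 0; kb < scale_K; ++kb) {
                float partial = 0.f;                          // unscaled sum for this K block
                for (int k = kb * bk; k < (kb + 1) * bk; ++k)
                    partial += A[std::size_t(m) * K + k] * B[std::size_t(n) * K + k];
                const float s = a_scale[std::size_t(m) * scale_K + kb] *
                                b_scale[std::size_t(n / bn) * scale_K + kb];
                acc += partial * s;                           // scale applied per K block
            }
            C[std::size_t(m) * N + n] = acc;
        }
    return C;
}
```

With per-tensor or per-row/per-column scales, `s` does not depend on `kb`, so it factors out of the K sum and can be applied once in the epilogue. With per-K-block scales it cannot, so the scale multiply has to live inside the K loop, after each K tile's MFMA work.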

Beyond raw compute, performance depends on the interplay of data movement, memory layouts, tiling schedules, and precision strategies. Because the best kernel depends on both the problem shape and the hardware, a generate-and-evaluate loop is needed to navigate these architectural trade-offs.

The MI355X (CDNA 4) targets very high peak throughput. AMD's published architecture material lists peak FP8 throughput at roughly 5 PFLOPS, about 1.9x that of the MI300X (CDNA 3).

AMD CDNA 4 architecture whitepaper:

https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-4-architecture-whitepaper.pdf
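Whether that peak is reachable for a given shape is a roofline question. The sketch below computes the arithmetic intensity of an FP8 GEMM under simplifying assumptions (FP8 A and B, BF16 output, each operand read or written from HBM exactly once, scale tensors ignored). The 5 PFLOPS figure comes from the whitepaper above; the 8 TB/s memory bandwidth in the usage note is our assumption and should be checked against AMD's spec sheet.

```cpp
#include <algorithm>
#include <cassert>

// FLOPs per byte of HBM traffic for C = A * B^T, assuming 1-byte FP8 inputs,
// a 2-byte BF16 output, and ideal reuse (every operand moves exactly once).
double gemm_intensity(double M, double N, double K) {
    const double flops = 2.0 * M * N * K;            // one multiply + one add per MAC
    const double bytes = M * K + N * K + 2.0 * M * N;
    return flops / bytes;
}

// Roofline ceiling: the lesser of peak compute and bandwidth * intensity (FLOP/s).
double roofline_flops(double intensity, double peak_flops, double bandwidth_bytes_per_s) {
    return std::min(peak_flops, intensity * bandwidth_bytes_per_s);
}
```

Under these assumptions the ridge point is 5e15 / 8e12 = 625 FLOP/byte. A square 4096 x 4096 x 4096 GEMM has an intensity of 2048 and is compute-bound, while a decode-style shape with M = 16 and N = K = 4096 sits near 32 FLOP/byte and is capped at roughly 250 TFLOPS by memory bandwidth, so the two need very different kernels.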

Benchmark setup and results

  • Hardware: AMD MI355X

  • ROCm: 7.0

  • Kernel type: FP8 GEMM

  • Baseline: reference implementation from AITER

This section summarizes benchmark results and highlights the cases where the agent-generated kernels match or exceed the baseline.
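When comparing kernels across shapes, raw runtimes are hard to read, so results are usually converted to effective throughput and to speedup over the baseline. A minimal sketch of that arithmetic follows; it counts only the 2*M*N*K multiply-add FLOPs and ignores scaling work, and the helper names are ours. How runtimes are measured (warm-up, averaging, synchronization) is a separate choice this sketch does not cover.

```cpp
#include <cassert>

// Effective GEMM throughput in TFLOP/s from a measured runtime in seconds.
double gemm_tflops(double M, double N, double K, double seconds) {
    assert(seconds > 0.0);
    return 2.0 * M * N * K / seconds / 1e12;
}

// Speedup of a candidate kernel over a baseline on the same shape (>1 means faster).
double speedup(double baseline_seconds, double candidate_seconds) {
    assert(candidate_seconds > 0.0);
    return baseline_seconds / candidate_seconds;
}
```

As an illustration with made-up timings, an 8192 x 8192 x 8192 GEMM that takes 1 ms corresponds to about 1100 TFLOPS.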

Notes on what improved

Across iterations, generation time decreased significantly while performance increased. The agent learns from previous experiments, reuses what works, and keeps refining the kernel.

Average generation time per iteration:

  • First run: about 6 hours

  • Second run: about 2 hours

  • Third run: about 30 minutes to reach a state-of-the-art kernel

The key takeaway is that the optimization loop gets faster and more effective over time, as the agent accumulates prior knowledge and converges in fewer iterations.

Code highlights

The snippets below are examples of the kinds of transformations the agent produces when optimizing FP8 GEMM. I will keep adding excerpts here.

This excerpt shows an MFMA inner loop with four independent MFMA calls to increase instruction-level parallelism (ILP). After the MFMA loop, the code applies the A and B scale factors before accumulating into FP32.

Key ideas to notice:

  • Multiple independent MFMA calls to keep the pipeline busy

  • Packing eight FP8 values into one 8-byte (64-bit) operand per lane

  • Scale application fused close to accumulation

// MFMA inner loop: TILE_K / MFMA_K = 64 / 16 = 4 steps
// Zero the per-K-tile partial sums (16 floats per lane for each 32x32 MFMA output tile)
tmp00 = {0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f};
tmp01 = {0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f};
tmp10 = {0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f};
tmp11 = {0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f,0.f};

// For MFMA 32x32x16 with FP8:
// Each lane provides 8 bytes of A and 8 bytes of B.
#pragma unroll 4
for (int k_inner = 0; k_inner < TILE_K; k_inner += MFMA_K) {
    const int k_off = k_inner + k_group;

    long long a0 = pack8_fp8(&a_smem[(warp_m +      mfma_row) * TILE_K + k_off]);
    long long a1 = pack8_fp8(&a_smem[(warp_m + 32 + mfma_row) * TILE_K + k_off]);
    long long b0 = pack8_fp8(&b_smem[(warp_n +      mfma_row) * TILE_K + k_off]);
    long long b1 = pack8_fp8(&b_smem[(warp_n + 32 + mfma_row) * TILE_K + k_off]);

    // Four independent MFMA calls for more ILP
    tmp00 = mfma_fp8_32x32x16(a0, b0, tmp00);
    tmp01 = mfma_fp8_32x32x16(a0, b1, tmp01);
    tmp10 = mfma_fp8_32x32x16(a1, b0, tmp10);
    tmp11 = mfma_fp8_32x32x16(a1, b1, tmp11);
}

// Apply block scales and accumulate into FP32 accumulators.
// The 0.25f factor is the fp8_e4m3fnuz vs fp8_e4m3 bias correction explained in the next section.
const float b_sv0 = b_scale[b_scale_n_base0 * scale_K + scale_k_idx] * 0.25f;
const float b_sv1 = b_scale[b_scale_n_base1 * scale_K + scale_k_idx] * 0.25f;

#pragma unroll 16
for (int v = 0; v < 16; v++) {
    const int row_in_sub = (v / 4) * 8 + blk_id * 4 + (v % 4);
    const int gm0 = m_block + warp_m +      row_in_sub;
    const int gm1 = m_block + warp_m + 32 + row_in_sub;
    const float a_sv0 = (gm0 < M) ? a_scale[gm0 * scale_K + scale_k_idx] : 0.f;
    const float a_sv1 = (gm1 < M) ? a_scale[gm1 * scale_K + scale_k_idx] : 0.f;
    acc00[v] = fmaf(tmp00[v], a_sv0 * b_sv0, acc00[v]);
    acc01[v] = fmaf(tmp01[v], a_sv0 * b_sv1, acc01[v]);
    acc10[v] = fmaf(tmp10[v], a_sv1 * b_sv0, acc10[v]);
    acc11[v] = fmaf(tmp11[v], a_sv1 * b_sv1, acc11[v]);
}
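Two details of the excerpt are easy to check on the host. `pack8_fp8` is not shown, so the first function below is a hypothetical stand-in that assumes little-endian packing (byte 0 of K in the lowest byte, which is what a plain 8-byte load yields on AMD GPUs). The second reproduces the excerpt's `row_in_sub` formula, assuming `blk_id` is `lane / 32` as in the usual 32x32 MFMA output layout, and lets us confirm that the 16 accumulator slots of the two half-waves cover all 32 rows exactly once.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical host-side stand-in for the kernel's pack8_fp8 (real definition not
// shown): gather 8 consecutive FP8 bytes along K into one 64-bit MFMA operand,
// byte 0 in the least significant position.
int64_t pack8_fp8_host(const uint8_t* p) {
    uint64_t v = 0;
    for (int i = 0; i < 8; ++i) v |= uint64_t(p[i]) << (8 * i);
    int64_t out;
    std::memcpy(&out, &v, sizeof(out));  // reinterpret the bits as the signed operand type
    return out;
}

// Row within a 32x32 MFMA output tile held in accumulator slot v (0..15) by a lane
// in half-wave blk_id (assumed lane / 32); that lane's column is lane % 32.
int mfma32_row(int v, int blk_id) {
    assert(v >= 0 && v < 16 && (blk_id == 0 || blk_id == 1));
    return (v / 4) * 8 + blk_id * 4 + (v % 4);
}
```

The row mapping interleaves groups of four rows between the two half-waves (rows 0-3 in half-wave 0, rows 4-7 in half-wave 1, and so on), which is why the scale lookup in the epilogue recomputes `row_in_sub` per slot instead of using `v` directly.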

Split-K, fused heuristics, and dtype details

This excerpt documents a subtle FP8 detail and the correction applied in the kernel.

// fp8_e4m3fnuz vs fp8_e4m3 bias correction:
// MFMA interprets the operands as fp8_e4m3 (bias = 7), but the data is fp8_e4m3fnuz (bias = 8).
// Each operand therefore decodes 2x too large, so each product is 4x too large.
// Correction: multiply by 0.25f before applying the per-row and per-column scales.
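The factor of 0.25 can be verified exhaustively on the host by decoding every byte under both conventions. The sketch assumes standard OCP E4M3FN semantics for the "e4m3" side and the usual FNUZ semantics (bias 8, single NaN at 0x80, no negative zero); the decoder names are ours.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// OCP E4M3FN: bias 7, no infinities, S.1111.111 is NaN, 0x80 is -0.
float decode_e4m3fn(uint8_t b) {
    int s = b >> 7, e = (b >> 3) & 0xF, m = b & 0x7;
    if (e == 0xF && m == 0x7) return NAN;
    float mag = e == 0 ? std::ldexp(m / 8.0f, 1 - 7) : std::ldexp(1.0f + m / 8.0f, e - 7);
    return s ? -mag : mag;
}

// E4M3FNUZ: bias 8, no infinities, no negative zero; 0x80 is the only NaN.
float decode_e4m3fnuz(uint8_t b) {
    if (b == 0x80) return NAN;
    int s = b >> 7, e = (b >> 3) & 0xF, m = b & 0x7;
    float mag = e == 0 ? std::ldexp(m / 8.0f, 1 - 8) : std::ldexp(1.0f + m / 8.0f, e - 8);
    return s ? -mag : mag;
}
```

Every code that is finite under both decodings, subnormals included, comes out exactly 2x larger under E4M3FN, so any product of two operands is exactly 4x too large and `0.25f` undoes it. The exhaustive check also surfaces three exceptions: 0x80 is NaN only in FNUZ, and 0x7F/0xFF, which are ±240 (the FNUZ maximum), are NaN under OCP E4M3FN. If the hardware path follows OCP semantics for those codes, inputs saturated to ±240 would need care, so this is worth testing explicitly.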

And here is the accumulation into a temporary FP32 buffer for split-K, with the same bias correction applied:

// Atomic add into FP32 temp buffer (with bias correction)
int c_col_local = lane_id % MFMA_N;
int c_row_base  = (lane_id / MFMA_N) * 4;
int gn          = block_n + wave_n_off + c_col_local;

if (gn < N) {
    // Top tile
    #pragma unroll
    for (int blk = 0; blk < 4; blk++) {
        int row_off = blk * 8;
        #pragma unroll
        for (int elem = 0; elem < 4; elem++) {
            int gm = block_m + c_row_base + row_off + elem;
            if (gm < M) {
                float val = acc_top[blk * 4 + elem] * BIAS_CORRECTION;
                atomicAdd(&temp[gm * N + gn], val);
            }
        }
    }
    // Bottom tile
    #pragma unroll
    for (int blk = 0; blk < 4; blk++) {
        int row_off = blk * 8;
        #pragma unroll
        for (int elem = 0; elem < 4; elem++) {
            int gm = block_m + MFMA_M + c_row_base + row_off + elem;
            if (gm < M) {
                float val = acc_bot[blk * 4 + elem] * BIAS_CORRECTION;
                atomicAdd(&temp[gm * N + gn], val);
            }
        }
    }
}
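The idea behind the split-K path can be emulated on the host in a few lines. In this sketch, which assumes already-decoded float inputs and uses names of our own, K is cut into equal slices, each slice (one workgroup on the GPU) computes a partial product, and the partials are summed into a zero-initialized FP32 buffer. The slices run sequentially here, so a plain `+=` stands in for `atomicAdd`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host emulation of split-K for C = A * B^T (A: M x K, B: N x K, row-major).
// Each of the `splits` K slices plays the role of one GPU workgroup.
std::vector<float> split_k_gemm(const std::vector<float>& A, const std::vector<float>& B,
                                int M, int N, int K, int splits) {
    assert(splits > 0 && K % splits == 0);
    const int k_per_split = K / splits;
    std::vector<float> temp(std::size_t(M) * N, 0.f);  // FP32 buffer, must start at zero
    for (int s = 0; s < splits; ++s) {
        const int k0 = s * k_per_split;
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float partial = 0.f;
                for (int k = k0; k < k0 + k_per_split; ++k)
                    partial += A[std::size_t(m) * K + k] * B[std::size_t(n) * K + k];
                // On the GPU: atomicAdd(&temp[m * N + n], partial)
                temp[std::size_t(m) * N + n] += partial;
            }
    }
    return temp;
}
```

Splitting K creates more independent work items when M and N are too small to fill the GPU with output tiles alone. The costs are extra FP32 traffic to the temporary buffer, a separate step to zero it beforehand and convert it to the output type afterwards (neither is shown in the excerpt), and run-to-run non-determinism, because the order in which floating-point atomics land can change the rounding of the final sum.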

Copyright © 2026 MakoRA. All rights reserved.
