Matrix Multiplication

This kernel extends the simplest Triton kernel to perform matrix multiplication.

This is the foundational kernel on which more advanced kernels are built: every advanced matrix-multiplication kernel is an extension of this one.

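It computes the whole output matrix in parallel by launching one program instance (a "thread" in this walkthrough) per output element. Let's break it down and highlight the differences from the 2D array addition kernel.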


Kernel Code Breakdown

index_c = row_idx * n_cols + col_idx
sum = 0.0
for k in range(k_dim):
    index_a = row_idx * k_dim + k
    index_b = k * n_cols + col_idx
    a = tl.load(a_ptr + index_a)
    b = tl.load(b_ptr + index_b)
    sum += a * b

Key Difference: Matrix Multiplication Loop

  • Why Two Different Indices? (Below, i = row_idx and j = col_idx.)
    • index_a = row_idx * k_dim + k walks along row i of matrix A. A is n_rows × k_dim, so each row holds k_dim elements.
    • index_b = k * n_cols + col_idx walks down column j of matrix B. B is k_dim × n_cols, so consecutive rows are n_cols elements apart.
    • Each thread computes one output element C[i, j] by iterating over the shared dimension k_dim.
  • The for loop performs:
    • Loads: fetching A[i, k] and B[k, j].
    • Multiply: computing A[i, k] × B[k, j].
    • Accumulate: adding the products into sum, which ends up holding C[i, j].
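To check the index arithmetic without a GPU, the loop can be emulated in plain Python. This is a minimal sketch, not Triton code: emulate_program is a hypothetical helper that plays the role of one program instance, and the flattened lists stand in for the raw pointers the kernel receives (assuming contiguous, row-major inputs). The shapes are deliberately non-square so that mixing up k_dim and n_cols would produce a wrong answer.

```python
def emulate_program(a_flat, b_flat, row_idx, col_idx, n_cols, k_dim):
    # Hypothetical plain-Python stand-in for ONE kernel program instance.
    # a_flat / b_flat are row-major flattened A and B, like the raw pointers.
    acc = 0.0
    for k in range(k_dim):
        index_a = row_idx * k_dim + k   # step along row `row_idx` of A
        index_b = k * n_cols + col_idx  # step down column `col_idx` of B
        acc += a_flat[index_a] * b_flat[index_b]  # multiply and accumulate
    return acc

n_rows, n_cols, k_dim = 2, 3, 4  # non-square on purpose: n_cols != k_dim
A = [[float(r * k_dim + c + 1) for c in range(k_dim)] for r in range(n_rows)]
B = [[float(r * n_cols + c + 1) for c in range(n_cols)] for r in range(k_dim)]
a_flat = [x for row in A for x in row]  # row-major flatten of A
b_flat = [x for row in B for x in row]  # row-major flatten of B

# One emulated program per output element, like the (n_rows, n_cols) grid.
C = [[emulate_program(a_flat, b_flat, i, j, n_cols, k_dim) for j in range(n_cols)]
     for i in range(n_rows)]

# Textbook definition C[i][j] = sum_k A[i][k] * B[k][j], for comparison.
reference = [[sum(A[i][k] * B[k][j] for k in range(k_dim)) for j in range(n_cols)]
             for i in range(n_rows)]

print(C == reference)  # True
print(C)               # [[70.0, 80.0, 90.0], [158.0, 184.0, 210.0]]
```

Swapping k_dim and n_cols in either index formula makes the comparison fail for these shapes, which is a quick way to convince yourself the strides are right.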
tl.store(c_ptr + index_c, sum)

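This line stores the accumulated sum in the output matrix C at linear index index_c = row_idx * n_cols + col_idx, which is element C[i, j] in row-major order.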
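The formula index_c = row_idx * n_cols + col_idx is row-major (C-order) indexing, and index_a and index_b follow the same rule. The sketch below uses PyTorch on the CPU to check the formula against a flattened tensor, and shows why it quietly assumes contiguous inputs: a transposed view such as C.t() shares memory with C but has different strides. Assuming Triton forwards only the tensor's data pointer to the kernel (not its strides), a non-contiguous argument would need .contiguous() before the launch.

```python
import torch

n_rows, n_cols = 3, 4
C = torch.arange(n_rows * n_cols, dtype=torch.float32).reshape(n_rows, n_cols)
flat = C.flatten()  # row-major order, as a raw pointer into contiguous C sees it

for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        index_c = row_idx * n_cols + col_idx  # same formula as the kernel
        assert flat[index_c] == C[row_idx, col_idx]

print(C.stride())  # (4, 1): moving one row skips n_cols elements
Ct = C.t()         # transposed *view*: same memory, different strides
print(Ct.is_contiguous(), Ct.stride())  # False (1, 4)
print(Ct.contiguous().stride())         # (3, 1): a row-major copy again
```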


Complete Kernel Code

Below is the complete kernel along with the host-side code to call it:

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, n_rows: tl.constexpr, n_cols: tl.constexpr, k_dim: tl.constexpr):
    row_idx = tl.program_id(0)  # Get row index
    col_idx = tl.program_id(1)  # Get column index
    if row_idx < n_rows and col_idx < n_cols:  # Bounds check
        index_c = row_idx * n_cols + col_idx  # Compute linear index for C
        sum = 0.0
        for k in range(k_dim):  # Loop over shared dimension
            index_a = row_idx * k_dim + k
            index_b = k * n_cols + col_idx
            a = tl.load(a_ptr + index_a)  # Load element from A
            b = tl.load(b_ptr + index_b)  # Load element from B
            sum += a * b  # Multiply and accumulate
        tl.store(c_ptr + index_c, sum)  # Store result in C

# Host-side code
import torch

n_rows, n_cols, k_dim = 4, 4, 4  # Dimensions of the matrices
A = torch.arange(1, n_rows * k_dim + 1, device="cuda", dtype=torch.float32).reshape(n_rows, k_dim)  # Matrix A
B = torch.arange(1, k_dim * n_cols + 1, device="cuda", dtype=torch.float32).reshape(k_dim, n_cols)  # Matrix B
C = torch.empty(n_rows, n_cols, device="cuda", dtype=torch.float32)  # Output matrix

grid_size = (n_rows, n_cols)  # One thread per element in C
matmul_kernel[grid_size](A, B, C, n_rows=n_rows, n_cols=n_cols, k_dim=k_dim)

print(f"Result:\n{C}")
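A CPU reference is a cheap way to validate the kernel's output. The sketch below rebuilds the same A and B on the CPU and computes A @ B with PyTorch; on a CUDA machine, the kernel's C should match it, for example via torch.allclose(C.cpu(), expected).

```python
import torch

n_rows, n_cols, k_dim = 4, 4, 4
# Same inputs as the host code above, but created on the CPU.
A = torch.arange(1, n_rows * k_dim + 1, dtype=torch.float32).reshape(n_rows, k_dim)
B = torch.arange(1, k_dim * n_cols + 1, dtype=torch.float32).reshape(k_dim, n_cols)

expected = A @ B  # reference result from PyTorch's own matmul
print(expected.tolist())
# [[90.0, 100.0, 110.0, 120.0], [202.0, 228.0, 254.0, 280.0],
#  [314.0, 356.0, 398.0, 440.0], [426.0, 484.0, 542.0, 600.0]]

# With a GPU, after launching matmul_kernel as above:
# assert torch.allclose(C.cpu(), expected)
```

For example, C[0, 0] = 1·1 + 2·5 + 3·9 + 4·13 = 90, which is exactly what one program instance accumulates over its k_dim = 4 iterations.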

Key Differences from 2D Array Addition

  1. Matrix Multiplication Loop: Unlike addition, each output element requires summing the products of corresponding elements from a row of A and a column of B.
  2. Linear Index Calculation: Besides index_c for C, we compute two separate indices for A and B (index_a and index_b).
  3. Grid Dimensions: The grid size remains (n_rows, n_cols), ensuring one thread per element in C.
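The grid and loop structure also determine how much work this kernel does. The hypothetical helper below only counts it: one program per element of C, k_dim multiply-adds per program, and two scalar loads per multiply-add. Every element of A is loaded once per column of C (n_cols times) and every element of B once per row of C (n_rows times), which is the main reason this version is a teaching kernel rather than a fast one.

```python
def naive_matmul_cost(n_rows, n_cols, k_dim):
    # Hypothetical helper: work done by the one-element-per-program kernel.
    programs = n_rows * n_cols        # grid = (n_rows, n_cols)
    multiply_adds = programs * k_dim  # k_dim loop iterations per program
    scalar_loads = 2 * multiply_adds  # one load from A and one from B per iteration
    return programs, multiply_adds, scalar_loads

print(naive_matmul_cost(4, 4, 4))           # (16, 64, 128)
print(naive_matmul_cost(1024, 1024, 1024))  # (1048576, 1073741824, 2147483648)
```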

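This kernel demonstrates how the element-wise pattern extends to matrix multiplication using Triton's 2D launch grid. It favors readability over speed: each program loads one scalar at a time, whereas optimized Triton matmul kernels compute a whole block of C per program. 🚀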