Matrix Multiplication
This kernel extends the simplest Triton kernel to perform matrix multiplication.
This is a foundational kernel: many advanced kernels are, at heart, extensions of this one.
It operates on entire matrices in parallel, launching one thread (program instance) per output element. Let's break it down and highlight the differences.
Kernel Code Breakdown
```python
index_c = row_idx * n_cols + col_idx
sum = 0.0
for k in range(k_dim):
    index_a = row_idx * k_dim + k
    index_b = k * n_cols + col_idx
    a = tl.load(a_ptr + index_a)
    b = tl.load(b_ptr + index_b)
    sum += a * b
```
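For intuition, here is the same per-element computation written as plain Python over row-major flattened lists. This is an illustrative sketch, not part of the kernel; the names `a_flat` and `b_flat` are assumptions standing in for the flattened matrices:

```python
# Sketch: what one kernel instance computes, in plain Python.
# a_flat and b_flat are hypothetical row-major flattenings of A and B.
def matmul_element(a_flat, b_flat, row_idx, col_idx, n_cols, k_dim):
    total = 0.0
    for k in range(k_dim):
        index_a = row_idx * k_dim + k    # walk along row `row_idx` of A
        index_b = k * n_cols + col_idx   # walk down column `col_idx` of B
        total += a_flat[index_a] * b_flat[index_b]
    return total
```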
Key Difference: Matrix Multiplication Loop
- Why Two Different Indexes? (a concrete index trace follows this list)
  - `index_a = row_idx * k_dim + k` accesses elements from row \( i \) of matrix A.
  - `index_b = k * n_cols + col_idx` accesses elements from column \( j \) of matrix B.
- Each thread computes one output element by iterating over the shared dimension `k_dim`.
- The for loop performs:
  - Load: fetching the corresponding elements from A and B.
  - Multiply: computing \( A[i,k] \times B[k,j] \).
  - Accumulate: summing up the products to form \( C[i,j] \).
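To make the index arithmetic concrete, here is a small trace (added for illustration) for the thread that computes \( C[1,2] \), assuming `n_cols = k_dim = 4`:

```python
# Trace the flat indices touched by the thread for C[1, 2]
# in 4x4 row-major matrices (n_cols = k_dim = 4).
row_idx, col_idx, n_cols, k_dim = 1, 2, 4, 4
for k in range(k_dim):
    index_a = row_idx * k_dim + k    # 4, 5, 6, 7   -> row 1 of A
    index_b = k * n_cols + col_idx   # 2, 6, 10, 14 -> column 2 of B
    print(f"k={k}: index_a={index_a}, index_b={index_b}")
```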
```python
tl.store(c_ptr + index_c, sum)
```

Stores the computed sum in the output matrix `C`.
Complete Kernel Code
Below is the complete kernel along with the host-side code to call it:
```python
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr,
                  n_rows: tl.constexpr, n_cols: tl.constexpr, k_dim: tl.constexpr):
    row_idx = tl.program_id(0)  # Get row index
    col_idx = tl.program_id(1)  # Get column index
    if row_idx < n_rows and col_idx < n_cols:  # Bounds check
        index_c = row_idx * n_cols + col_idx  # Compute linear index for C
        sum = 0.0
        for k in range(k_dim):  # Loop over shared dimension
            index_a = row_idx * k_dim + k
            index_b = k * n_cols + col_idx
            a = tl.load(a_ptr + index_a)  # Load element from A
            b = tl.load(b_ptr + index_b)  # Load element from B
            sum += a * b  # Multiply and accumulate
        tl.store(c_ptr + index_c, sum)  # Store result in C

# Host-side code
import torch

n_rows, n_cols, k_dim = 4, 4, 4  # Dimensions of the matrices
A = torch.arange(1, n_rows * k_dim + 1, device="cuda", dtype=torch.float32).reshape(n_rows, k_dim)  # Matrix A
B = torch.arange(1, k_dim * n_cols + 1, device="cuda", dtype=torch.float32).reshape(k_dim, n_cols)  # Matrix B
C = torch.empty(n_rows, n_cols, device="cuda", dtype=torch.float32)  # Output matrix

grid_size = (n_rows, n_cols)  # One thread per element in C
matmul_kernel[grid_size](A, B, C, n_rows=n_rows, n_cols=n_cols, k_dim=k_dim)
print(f"Result:\n{C}")
```
Key Differences from 2D Array Addition
- Matrix Multiplication Loop: Unlike addition, each output element requires summing the products of corresponding elements in a row of A and a column of B (a sketch of the addition kernel follows this list for comparison).
- Linear Index Calculation: We compute two separate indices for A and B (`index_a` and `index_b`).
- Grid Dimensions: The grid size remains `(n_rows, n_cols)`, ensuring one thread per element in C.
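For contrast, here is a minimal sketch of what the 2D addition kernel looks like. This is a reconstruction for comparison purposes and may differ from the earlier chapter's exact code; note the single shared index and the absence of a `k` loop:

```python
# Hypothetical 2D addition kernel, shown for comparison: one shared
# linear index for A, B, and C, and no loop over a shared dimension.
import triton
import triton.language as tl

@triton.jit
def add2d_kernel(a_ptr, b_ptr, c_ptr, n_rows: tl.constexpr, n_cols: tl.constexpr):
    row_idx = tl.program_id(0)
    col_idx = tl.program_id(1)
    if row_idx < n_rows and col_idx < n_cols:  # Bounds check
        index = row_idx * n_cols + col_idx  # Same index for A, B, and C
        a = tl.load(a_ptr + index)
        b = tl.load(b_ptr + index)
        tl.store(c_ptr + index, a + b)  # Element-wise sum
```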
This kernel demonstrates how element-wise operations can be extended to matrix multiplication, leveraging Triton’s parallelism efficiently. 🚀