2D Array Addition
This kernel extends the simplest Triton kernel to perform element-wise 2D array addition. Instead of adding two single elements or 1D vectors, it operates on multiple pairs of elements in parallel using multiple threads. Let’s break it down and highlight the differences.
Kernel Code Breakdown
def array_add_kernel(a_ptr, b_ptr, c_ptr, n_rows: tl.constexpr, n_cols: tl.constexpr):
This defines the kernel function. It takes five arguments:
- `a_ptr`: A pointer to the first input 2D array.
- `b_ptr`: A pointer to the second input 2D array.
- `c_ptr`: A pointer to the output 2D array.
- `n_rows`: The number of rows in the arrays, passed as a compile-time constant.
- `n_cols`: The number of columns in the arrays, passed as a compile-time constant.
row_idx = tl.program_id(0)
col_idx = tl.program_id(1)
`tl.program_id(0)` retrieves the unique index of the current kernel instance along the row axis (axis 0), while `tl.program_id(1)` retrieves the index along the column axis (axis 1). Triton calls each instance a "program", but since this kernel processes one element per instance, each behaves like a single thread. Each thread processes a unique element based on these indices.
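Conceptually, launching on a two-axis grid is like iterating a pair of nested loops on the host. This pure-Python sketch (not Triton code) enumerates the (row_idx, col_idx) pairs a (2, 2) grid would produce:

```python
# Pure-Python sketch of how a two-axis launch grid enumerates program IDs.
# Each (row_idx, col_idx) pair corresponds to one kernel instance.
n_rows, n_cols = 2, 2

program_ids = [(row_idx, col_idx)
               for row_idx in range(n_rows)
               for col_idx in range(n_cols)]

print(program_ids)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

In the real launch, these instances run in parallel on the GPU rather than sequentially.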
if row_idx < n_rows and col_idx < n_cols:
This guards against out-of-bounds memory accesses: any thread whose row or column index falls outside the array dimensions skips the computation entirely. With a grid of exactly (n_rows, n_cols) the check never fails, but it is good practice for when the grid may be larger than the data.
index = row_idx * n_cols + col_idx
Computes the linearized index for accessing elements in a flattened representation of the 2D array.
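This is standard row-major (C-order) flattening, the same layout NumPy and PyTorch use by default. A quick sketch checking the formula against NumPy's own flat ordering:

```python
import numpy as np

n_rows, n_cols = 4, 4
a = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)

# The linearized index recovers the same element as 2D indexing,
# because NumPy arrays are row-major (C-order) by default.
for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        index = row_idx * n_cols + col_idx
        assert a.flat[index] == a[row_idx, col_idx]
```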
a = tl.load(a_ptr + index)
Loads the corresponding element of the array pointed to by `a_ptr` into a register for computation.
b = tl.load(b_ptr + index)
Similarly, loads the corresponding element of the array pointed to by `b_ptr`.
tl.store(c_ptr + index, a + b)
Adds the loaded values `a` and `b` and stores the result in the corresponding position of the output array.
Complete Kernel Code
Below is the complete kernel along with the host-side code to call it:
import triton
import triton.language as tl

@triton.jit
def array_add_kernel(a_ptr, b_ptr, c_ptr, n_rows: tl.constexpr, n_cols: tl.constexpr):
    row_idx = tl.program_id(0)  # Get row index
    col_idx = tl.program_id(1)  # Get column index
    if row_idx < n_rows and col_idx < n_cols:  # Bounds check
        index = row_idx * n_cols + col_idx  # Compute linear index
        a = tl.load(a_ptr + index)  # Load element from input array A
        b = tl.load(b_ptr + index)  # Load element from input array B
        tl.store(c_ptr + index, a + b)  # Perform addition and store the result

# Host-side code
import torch

n_rows, n_cols = 4, 4  # Dimensions of the arrays
a = torch.arange(1, n_rows * n_cols + 1, device="cuda", dtype=torch.float32).reshape(n_rows, n_cols)  # 2D array A
b = torch.arange(1, n_rows * n_cols + 1, device="cuda", dtype=torch.float32).reshape(n_rows, n_cols)  # 2D array B
c = torch.empty_like(a)  # Output array
grid_size = (n_rows, n_cols)  # One thread per element
array_add_kernel[grid_size](a, b, c, n_rows=n_rows, n_cols=n_cols)
print(f"Result:\n{c}")
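If you don't have a GPU handy, the kernel's per-element logic can be mirrored on the CPU. This is a NumPy emulation of the same loops, not Triton code; the `array_add_emulated` helper is introduced here purely for illustration:

```python
import numpy as np

def array_add_emulated(a, b, n_rows, n_cols):
    """CPU emulation of array_add_kernel: one 'thread' per element."""
    a_flat, b_flat = a.ravel(), b.ravel()
    c_flat = np.empty_like(a_flat)
    for row_idx in range(n_rows):          # plays the role of tl.program_id(0)
        for col_idx in range(n_cols):      # plays the role of tl.program_id(1)
            index = row_idx * n_cols + col_idx
            c_flat[index] = a_flat[index] + b_flat[index]  # load, add, store
    return c_flat.reshape(n_rows, n_cols)

n_rows, n_cols = 4, 4
a = np.arange(1, n_rows * n_cols + 1, dtype=np.float32).reshape(n_rows, n_cols)
b = np.arange(1, n_rows * n_cols + 1, dtype=np.float32).reshape(n_rows, n_cols)
c = array_add_emulated(a, b, n_rows, n_cols)
assert np.array_equal(c, a + b)  # matches the vectorized reference
```

The difference, of course, is that the GPU runs all sixteen element additions concurrently rather than in nested loops.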
Illustrated Example: 2×2 Array Addition in Triton
Given Input Matrices
A = [[1, 2],
[3, 4]]
B = [[5, 6],
[7, 8]]
Flattened in Memory (Row-Major Order):
A (Memory Address: 1000) → [1, 2, 3, 4]
B (Memory Address: 2000) → [5, 6, 7, 8]
C (Memory Address: 3000) → [?, ?, ?, ?] (To store results)
1️⃣ Assign Thread IDs
Each thread processes one element:
| Thread | (row_idx, col_idx) | index = row_idx * 2 + col_idx |
|---|---|---|
| T0 | (0,0) | 0 |
| T1 | (0,1) | 1 |
| T2 | (1,0) | 2 |
| T3 | (1,1) | 3 |
2️⃣ Compute Memory Addresses
Using (base address + index), each thread accesses memory:
| Thread | A Address | B Address | C Address |
|---|---|---|---|
| T0 | 1000 + 0 = 1000 | 2000 + 0 = 2000 | 3000 + 0 = 3000 |
| T1 | 1000 + 1 = 1001 | 2000 + 1 = 2001 | 3000 + 1 = 3001 |
| T2 | 1000 + 2 = 1002 | 2000 + 2 = 2002 | 3000 + 2 = 3002 |
| T3 | 1000 + 3 = 1003 | 2000 + 3 = 2003 | 3000 + 3 = 3003 |
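One subtlety worth noting: Triton pointer arithmetic is element-wise, so `a_ptr + index` advances by whole elements regardless of the data type's byte size; the base addresses 1000, 2000, and 3000 above are illustrative offsets, not real pointers. A small sketch reproducing the table:

```python
# Reproduce the address table above. Triton pointer arithmetic is
# element-wise, so a_ptr + index advances by whole elements; the base
# addresses (1000, 2000, 3000) are illustrative, not real pointers.
bases = {"A": 1000, "B": 2000, "C": 3000}
n_cols = 2

rows = []
for row_idx in range(2):
    for col_idx in range(n_cols):
        index = row_idx * n_cols + col_idx
        rows.append((bases["A"] + index, bases["B"] + index, bases["C"] + index))

print(rows[3])  # (1003, 2003, 3003) -- thread T3 at (1, 1)
```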
3️⃣ Load Values from Memory
Each thread loads values from A and B:
| Thread | A[index] | B[index] |
|---|---|---|
| T0 | 1 | 5 |
| T1 | 2 | 6 |
| T2 | 3 | 7 |
| T3 | 4 | 8 |
4️⃣ Perform Element-wise Addition
Each thread computes:
C[index] = A[index] + B[index]
| Thread | C[index] Calculation | Result Stored in C |
|---|---|---|
| T0 | 1 + 5 = 6 | C[0] = 6 |
| T1 | 2 + 6 = 8 | C[1] = 8 |
| T2 | 3 + 7 = 10 | C[2] = 10 |
| T3 | 4 + 8 = 12 | C[3] = 12 |
5️⃣ Store Results
Final memory layout for C:
C (Memory Address: 3000) → [6, 8, 10, 12]
Final Output Matrix
C = [[6, 8],
[10, 12]]
🎯 Each thread handled one element, computed the sum, and stored it back efficiently in parallel! 🚀
Key Differences from 1D Vector Addition
- 2D Thread Indexing: Instead of a single index, we use `(row_idx, col_idx)`, requiring two `tl.program_id()` calls.
- Bounds Checking: We check both row and column bounds with `if row_idx < n_rows and col_idx < n_cols`.
- Linear Index Calculation: Instead of direct indexing, we compute `index = row_idx * n_cols + col_idx` to flatten the 2D access pattern.
- Grid Dimensions: The grid size is now `(n_rows, n_cols)`, ensuring one thread per element in the 2D array.
This kernel demonstrates how simple vector operations can be extended to multi-dimensional workloads, leveraging Triton’s parallelism efficiently. 🚀