2D Array Addition

This kernel extends the simplest Triton kernel to perform element-wise 2D array addition. Instead of adding two scalars or two 1D vectors, it operates on many pairs of elements in parallel by launching a 2D grid with one thread per element. Let’s break it down and highlight the differences.


Kernel Code Breakdown

def array_add_kernel(a_ptr, b_ptr, c_ptr, n_rows: tl.constexpr, n_cols: tl.constexpr):

This defines the kernel function. It takes five arguments:

  • a_ptr: A pointer to the first input 2D array.
  • b_ptr: A pointer to the second input 2D array.
  • c_ptr: A pointer to the output 2D array.
  • n_rows: The number of rows in the arrays, passed as a constant.
  • n_cols: The number of columns in the arrays, passed as a constant.
row_idx = tl.program_id(0)
col_idx = tl.program_id(1)

tl.program_id(0) retrieves the unique index of the current thread (strictly speaking, Triton calls each grid entry a program instance) along the row axis, while tl.program_id(1) retrieves its index along the column axis. Each thread processes exactly one element determined by these two indices.

if row_idx < n_rows and col_idx < n_cols:

This bounds check ensures that threads whose indices fall outside the array (row_idx ≥ n_rows or col_idx ≥ n_cols) do not perform out-of-bounds memory accesses. Here the grid exactly matches the array dimensions, so the check is defensive, but it becomes essential whenever the grid is rounded up past the array size.

index = row_idx * n_cols + col_idx

Computes the linearized, row-major index for accessing elements in a flattened representation of the 2D array: row row_idx starts at offset row_idx * n_cols, and col_idx selects the element within that row.
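As a quick sanity check, the same row-major flattening can be reproduced in plain Python (the 4×4 shape here is just illustrative):

```python
n_rows, n_cols = 4, 4

# Build a 2D list and its flattened (row-major) counterpart
grid = [[r * n_cols + c + 1 for c in range(n_cols)] for r in range(n_rows)]
flat = [v for row in grid for v in row]

# The linearized index recovers the same element as 2D indexing
row_idx, col_idx = 2, 3
index = row_idx * n_cols + col_idx
print(grid[row_idx][col_idx] == flat[index])  # True
```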

a = tl.load(a_ptr + index)

Loads the corresponding element of the array pointed to by a_ptr into a register for computation.

b = tl.load(b_ptr + index)

Similarly, loads the corresponding element of the array pointed to by b_ptr.

tl.store(c_ptr + index, a + b)

Adds the loaded values a and b and stores the result in the corresponding position of the output array.


Complete Kernel Code

Below is the complete kernel along with the host-side code to call it:

import triton
import triton.language as tl

@triton.jit
def array_add_kernel(a_ptr, b_ptr, c_ptr, n_rows: tl.constexpr, n_cols: tl.constexpr):
    row_idx = tl.program_id(0)  # Get row index
    col_idx = tl.program_id(1)  # Get column index
    if row_idx < n_rows and col_idx < n_cols:  # Bounds check
        index = row_idx * n_cols + col_idx  # Compute linear index
        a = tl.load(a_ptr + index)  # Load element from input array A
        b = tl.load(b_ptr + index)  # Load element from input array B
        tl.store(c_ptr + index, a + b)  # Perform addition and store the result

# Host-side code
import torch

n_rows, n_cols = 4, 4  # Dimensions of the arrays
a = torch.arange(1, n_rows * n_cols + 1, device="cuda", dtype=torch.float32).reshape(n_rows, n_cols)  # 2D Array A
b = torch.arange(1, n_rows * n_cols + 1, device="cuda", dtype=torch.float32).reshape(n_rows, n_cols)  # 2D Array B
c = torch.empty_like(a)  # Output array

grid_size = (n_rows, n_cols)  # One thread per element
array_add_kernel[grid_size](a, b, c, n_rows=n_rows, n_cols=n_cols)

print(f"Result:\n{c}")
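If a GPU is not available, the kernel’s logic can be emulated on the CPU with an explicit loop over the launch grid. This is plain Python standing in for the grid, not Triton itself, but it computes exactly what each thread would:

```python
def emulate_array_add(a, b, n_rows, n_cols):
    """CPU stand-in: one loop iteration per (row_idx, col_idx) thread."""
    c = [0.0] * (n_rows * n_cols)          # flattened output buffer
    for row_idx in range(n_rows):          # plays the role of program_id(0)
        for col_idx in range(n_cols):      # plays the role of program_id(1)
            index = row_idx * n_cols + col_idx
            c[index] = a[index] + b[index]
    return c

a_flat = list(range(1, 17))  # same values as the torch.arange inputs above
b_flat = list(range(1, 17))
c_flat = emulate_array_add(a_flat, b_flat, 4, 4)
print(c_flat[:4])  # [2, 4, 6, 8]
```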

Illustrated Example: 2×2 Array Addition in Triton

Given Input Matrices

A = [[1, 2],
     [3, 4]]

B = [[5, 6],
     [7, 8]]

Flattened in Memory (Row-Major Order):

A (Memory Address: 1000) → [1, 2, 3, 4]
B (Memory Address: 2000) → [5, 6, 7, 8]
C (Memory Address: 3000) → [?, ?, ?, ?] (To store results)

1️⃣ Assign Thread IDs

Each thread processes one element:

Thread | (row_idx, col_idx) | index = row_idx * 2 + col_idx
T0     | (0, 0)             | 0
T1     | (0, 1)             | 1
T2     | (1, 0)             | 2
T3     | (1, 1)             | 3

2️⃣ Compute Memory Addresses

Using base address + index, each thread computes the location it will access (note that Triton pointer arithmetic advances in elements, not bytes, so the +1 steps below are element offsets):

Thread | A Address       | B Address       | C Address
T0     | 1000 + 0 = 1000 | 2000 + 0 = 2000 | 3000 + 0 = 3000
T1     | 1000 + 1 = 1001 | 2000 + 1 = 2001 | 3000 + 1 = 3001
T2     | 1000 + 2 = 1002 | 2000 + 2 = 2002 | 3000 + 2 = 3002
T3     | 1000 + 3 = 1003 | 2000 + 3 = 2003 | 3000 + 3 = 3003

3️⃣ Load Values from Memory

Each thread loads values from A and B:

Thread | A[index] | B[index]
T0     | 1        | 5
T1     | 2        | 6
T2     | 3        | 7
T3     | 4        | 8

4️⃣ Perform Element-wise Addition

Each thread computes:

C[index] = A[index] + B[index]
Thread | C[index] Calculation | Result Stored in C
T0     | 1 + 5 = 6            | C[0] = 6
T1     | 2 + 6 = 8            | C[1] = 8
T2     | 3 + 7 = 10           | C[2] = 10
T3     | 4 + 8 = 12           | C[3] = 12

5️⃣ Store Results in C

Final memory layout for C:

C (Memory Address: 3000) → [6, 8, 10, 12]

Final Output Matrix

C = [[6,  8],
     [10, 12]]
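The five steps above can be replayed in a few lines of plain Python (the illustrative memory addresses are omitted; only the index arithmetic matters):

```python
A = [1, 2, 3, 4]   # A flattened in row-major order
B = [5, 6, 7, 8]   # B flattened in row-major order
n_cols = 2

C = [0] * 4
for row_idx in range(2):
    for col_idx in range(2):
        index = row_idx * n_cols + col_idx   # T0..T3 each handle one index
        C[index] = A[index] + B[index]

print(C)  # [6, 8, 10, 12]
```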

🎯 Each thread handled one element, computed the sum, and stored it back efficiently in parallel! 🚀

Key Differences from 1D Vector Addition

  1. 2D Thread Indexing: Instead of a single index, we use (row_idx, col_idx), requiring two tl.program_id() calls.
  2. Bounds Checking: We check both row and column bounds using if row_idx < n_rows and col_idx < n_cols.
  3. Linear Index Calculation: Instead of direct indexing, we compute index = row_idx * n_cols + col_idx to flatten the 2D access pattern.
  4. Grid Dimensions: The grid size is now (n_rows, n_cols), ensuring one thread per element in the 2D array.
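One way to see the connection back to the 1D case: a 2D grid of (n_rows, n_cols) threads visits exactly the same set of indices as a 1D grid of n_rows * n_cols threads combined with divmod. A small sketch, assuming the row-major layout used throughout:

```python
n_rows, n_cols = 4, 4

# Indices visited by the 2D grid: one (row, col) pair per thread
indices_2d = sorted(r * n_cols + c for r in range(n_rows) for c in range(n_cols))

# Indices visited by an equivalent 1D grid, recovering (row, col) via divmod
indices_1d = []
for pid in range(n_rows * n_cols):
    row_idx, col_idx = divmod(pid, n_cols)
    indices_1d.append(row_idx * n_cols + col_idx)

print(indices_2d == sorted(indices_1d))  # True
```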

This kernel demonstrates how simple vector operations can be extended to multi-dimensional workloads, leveraging Triton’s parallelism efficiently. 🚀