Vector Addition

This kernel extends the simplest Triton kernel to perform element-wise vector addition. Instead of adding two single scalars, it operates on many pairs of elements in parallel, one pair per program instance. Let’s break it down and highlight the differences.


Kernel Code Breakdown

def vector_add_kernel(a_ptr, b_ptr, c_ptr, n_elements: tl.constexpr):

This defines the kernel function. It takes four arguments:

  • a_ptr: A pointer to the first input vector.
  • b_ptr: A pointer to the second input vector.
  • c_ptr: A pointer to the output vector.
  • n_elements: The number of elements to process, passed as a compile-time constant (tl.constexpr).

idx = tl.program_id(0)

tl.program_id(0) retrieves the unique index of the current program instance along the 0th axis of the launch grid. Note that Triton's unit of parallelism is the program instance, not an individual GPU thread; each program here processes one element, selected by this index.

if idx < n_elements:

This guard ensures that program instances whose index is greater than or equal to the number of elements do not perform out-of-bounds memory accesses.
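Here the grid launches exactly one program per element, so the guard is purely defensive. It becomes essential once the grid is sized by ceiling division, i.e. rounded up so that fixed-size blocks cover the whole vector. A small sketch of that sizing arithmetic (the helper name `cdiv` is illustrative, though Triton ships an equivalent `triton.cdiv`):

```python
def cdiv(n, block):
    # Ceiling division: number of programs needed to cover n elements
    # when each program handles `block` elements.
    return (n + block - 1) // block

# With 8 elements and blocks of 3, three programs are launched;
# the last one spans indices 6..8, so index 8 must be guarded out.
print(cdiv(8, 3))  # -> 3
```

Whenever the grid is rounded up like this, some indices in the final program fall past the end of the vectors, and the bounds check is what keeps those accesses from touching invalid memory.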

a = tl.load(a_ptr + idx)

Loads the idx-th element of the vector pointed to by a_ptr into a register for computation.

b = tl.load(b_ptr + idx)

Similarly, loads the idx-th element of the vector pointed to by b_ptr.

tl.store(c_ptr + idx, a + b)

Adds the loaded values a and b and stores the result in the idx-th position of the output vector.
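The whole per-program logic above can be mirrored in plain Python, with `tl.load`/`tl.store` replaced by list indexing and the grid replaced by a loop over program IDs. This is only a CPU stand-in for illustration, not how Triton executes:

```python
def vector_add_program(a, b, c, idx, n_elements):
    # CPU stand-in for one program instance of the Triton kernel.
    if idx < n_elements:          # bounds check, as in the kernel
        c[idx] = a[idx] + b[idx]  # load both operands, add, store

a = [1.0, 2.0, 3.0, 4.0]
b = [10.0, 20.0, 30.0, 40.0]
c = [0.0] * 4
for pid in range(4):              # the grid: one "program" per element
    vector_add_program(a, b, c, pid, 4)
print(c)  # -> [11.0, 22.0, 33.0, 44.0]
```

On the GPU, the loop over `pid` does not exist; all program instances run in parallel, each seeing its own `tl.program_id(0)`.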


Complete Kernel Code

Below is the complete kernel along with the host-side code to call it:

import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, n_elements: tl.constexpr):
    idx = tl.program_id(0)  # Index of this program instance
    if idx < n_elements:  # Bounds check
        a = tl.load(a_ptr + idx)  # Load element from input vector A
        b = tl.load(b_ptr + idx)  # Load element from input vector B
        tl.store(c_ptr + idx, a + b)  # Perform addition and store the result

# Host-side code
import torch

n_elements = 8  # Number of elements in the vectors
a = torch.arange(1, n_elements + 1, device="cuda", dtype=torch.float32)  # Vector A: [1, 2, ..., 8]
b = torch.arange(1, n_elements + 1, device="cuda", dtype=torch.float32)  # Vector B: [1, 2, ..., 8]
c = torch.empty_like(a)  # Output vector

grid_size = (n_elements,)  # One program instance per element
vector_add_kernel[grid_size](a, b, c, n_elements=n_elements)  # Pass tensors directly; Triton extracts their pointers

print(f"Result: {c}")

Key Differences from Single-Element Addition

  1. Multiple Programs: Each program instance processes one element, enabling parallel computation across the whole vector.
  2. Grid Size: The grid size matches the number of elements (n_elements), so one program instance is launched per element.
  3. Bounds Checking: Added if idx < n_elements to prevent out-of-bounds memory access.
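In practice, Triton kernels rarely launch one program per element; each program instead handles a contiguous block of elements via tl.arange, with a mask playing the role of the bounds check. A plain-Python sketch of that blocked, masked pattern (BLOCK_SIZE of 4 is an arbitrary illustrative choice; real kernels use tl.load/tl.store with a mask argument):

```python
BLOCK_SIZE = 4

def vector_add_block(a, b, c, pid, n_elements):
    # CPU sketch of a blocked program instance: it covers BLOCK_SIZE
    # contiguous indices starting at pid * BLOCK_SIZE.
    offsets = [pid * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]
    for off in offsets:
        if off < n_elements:  # the mask: skip indices past the end
            c[off] = a[off] + b[off]

n = 6  # not a multiple of BLOCK_SIZE, so the last block is partial
a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
c = [0.0] * n
for pid in range((n + BLOCK_SIZE - 1) // BLOCK_SIZE):  # rounded-up grid
    vector_add_block(a, b, c, pid, n)
print(c)  # -> [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]
```

The second program's offsets 6 and 7 are masked out, which is exactly the situation the bounds check in the kernel above protects against.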

This kernel demonstrates the next step in learning Triton: scaling simple operations to vectorized workloads for parallel execution. 🚀