Vector Addition
This kernel extends the simplest Triton kernel to perform element-wise vector addition. Instead of adding a single pair of elements, it adds many pairs in parallel, launching one Triton program instance per element. Let's break it down and highlight the differences.
Kernel Code Breakdown
def vector_add_kernel(a_ptr, b_ptr, c_ptr, n_elements: tl.constexpr):
This defines the kernel function. It takes four arguments:
- `a_ptr`: A pointer to the first input vector.
- `b_ptr`: A pointer to the second input vector.
- `c_ptr`: A pointer to the output vector.
- `n_elements`: The number of elements to process, passed as a compile-time constant (`tl.constexpr`).
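A note on `n_elements: tl.constexpr`: marking an argument as `tl.constexpr` bakes its value into the compiled kernel, so Triton compiles (and caches) a separate specialization for each distinct value. A minimal sketch of that behavior (`a16`, `b16`, `c16` are hypothetical 16-element CUDA tensors; the 8-element `a`, `b`, `c` are defined in the host code below):

# Compiles a variant specialized for n_elements=8, then reuses it from cache
vector_add_kernel[(8,)](a, b, c, n_elements=8)
# A new constexpr value forces a fresh compilation, specialized for 16
vector_add_kernel[(16,)](a16, b16, c16, n_elements=16)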
idx = tl.program_id(0)
`tl.program_id(0)` retrieves the unique index of the current program instance along the 0th axis of the launch grid. Each program processes one element based on this index.
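To make the mapping concrete, a tiny throwaway kernel can print its own index (assuming a Triton version that provides `tl.device_print`):

@triton.jit
def show_program_ids():
    pid = tl.program_id(0)  # unique index along grid axis 0
    tl.device_print("pid", pid)  # a grid of (8,) prints ids 0 through 7

show_program_ids[(8,)]()  # launch 8 program instances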
if idx < n_elements:
This guard ensures that program instances whose index is at or beyond `n_elements` do not perform out-of-bounds memory operations.
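The guard is redundant when the grid exactly matches the vector length, as it does in the host code below, but it becomes essential whenever the grid is rounded up to a convenient size. A sketch (assuming `a`, `b`, `c` are 10-element CUDA tensors):

n_elements = 10
grid_size = (12,)  # grid padded past the data, e.g. rounded up to a multiple of 4
vector_add_kernel[grid_size](a, b, c, n_elements=n_elements)
# Programs 10 and 11 fail the idx < n_elements check and exit without
# touching memory; without the guard they would read and write past the
# end of the buffers.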
a = tl.load(a_ptr + idx)
Loads the `idx`-th element of the vector pointed to by `a_ptr` into a register for computation.
b = tl.load(b_ptr + idx)
Similarly, loads the `idx`-th element of the vector pointed to by `b_ptr`.
tl.store(c_ptr + idx, a + b)
Adds the loaded values `a` and `b` and stores the result in the `idx`-th position of the output vector.
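For reference, `tl.load` and `tl.store` also take a `mask` argument (and `tl.load` an `other` value for masked-off positions), which is how Triton code usually bounds-checks a whole block of indices at once instead of branching. A sketch of the same addition performed by a single program over a block of 8 indices:

offsets = tl.arange(0, 8)  # vector of indices 0..7 within one program
mask = offsets < n_elements  # per-index bounds guard
a = tl.load(a_ptr + offsets, mask=mask, other=0.0)  # masked-off loads yield 0.0
b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
tl.store(c_ptr + offsets, a + b, mask=mask)  # masked-off stores are skipped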
Complete Kernel Code
Below is the complete kernel along with the host-side code to call it:
import triton
import triton.language as tl
@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, n_elements: tl.constexpr):
    idx = tl.program_id(0)  # Index of this program instance
    if idx < n_elements:  # Bounds check
        a = tl.load(a_ptr + idx)  # Load element idx from input vector A
        b = tl.load(b_ptr + idx)  # Load element idx from input vector B
        tl.store(c_ptr + idx, a + b)  # Perform addition and store the result
# Host-side code
import torch
n_elements = 8 # Number of elements in the vectors
a = torch.arange(1, n_elements + 1, device="cuda", dtype=torch.float32) # Vector A: [1, 2, ..., 8]
b = torch.arange(1, n_elements + 1, device="cuda", dtype=torch.float32) # Vector B: [1, 2, ..., 8]
c = torch.empty_like(a) # Output vector
grid_size = (n_elements,) # One program instance per element
vector_add_kernel[grid_size](a, b, c, n_elements=n_elements)  # Pass tensors; Triton extracts their pointers
print(f"Result: {c}")
Key Differences from Single-Element Addition
- Multiple Program Instances: Each program processes one element, enabling parallel computation for vectors of any size.
- Grid Size: The grid size matches the number of elements (`n_elements`) to ensure one program per element.
- Bounds Checking: The `if idx < n_elements` guard prevents out-of-bounds memory access (the blocked variant sketched below expresses the same guard with a mask).
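One program per element works, but it leaves most of the GPU idle: a single Triton program can process a whole block of elements with vectorized loads and stores. Below is a sketch of that blocked pattern (the kernel name, `BLOCK_SIZE`, and the host snippet are our additions, assuming the same `a`, `b`, `c`, and `n_elements` as above):

@triton.jit
def vector_add_block_kernel(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # this program's slice of indices
    mask = offsets < n_elements  # guards the ragged final block
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(c_ptr + offsets, a + b, mask=mask)

BLOCK_SIZE = 1024
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)  # enough programs to cover the vector
vector_add_block_kernel[grid](a, b, c, n_elements, BLOCK_SIZE=BLOCK_SIZE)

Note that `n_elements` is now an ordinary runtime argument rather than a `tl.constexpr`, so changing the vector length no longer triggers a recompile; only `BLOCK_SIZE` stays a compile-time constant.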
This kernel demonstrates the next step in learning Triton: scaling simple operations to vectorized workloads for parallel execution. 🚀