Simplest Triton Kernel

This is the simplest Triton kernel you can write: a kernel that adds two single elements. Let's break it down line by line to understand how it works.


Kernel Code Breakdown

@triton.jit

This decorator marks the function as a Triton kernel. Triton will JIT-compile this function into GPU machine code when it is executed.

def add_kernel(a_ptr, b_ptr, c_ptr):

This defines the kernel function. It takes three arguments:

  • a_ptr: A pointer to a single input value.
  • b_ptr: A pointer to another single input value.
  • c_ptr: A pointer to the memory location where the result will be stored.

idx = tl.program_id(0)

tl.program_id(0) returns the unique index of the current program instance along axis 0 of the launch grid. (In Triton, each "program" is a block of work, not an individual GPU thread.) For this simplest kernel, only one program instance is launched, so idx will always be 0.
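The role of program_id is easiest to see with more than one program instance. Here is a pure-Python sketch (not Triton code) of how each instance in a grid receives its own index; the function name simulate_launch is invented for illustration:

```python
# A pure-Python sketch (not Triton) of how program_id indexing works.
# Each slot in the grid launches one program instance, and program_id(0)
# gives that instance its position along axis 0.

def simulate_launch(grid_size):
    """Return the idx each program instance would observe."""
    indices = []
    for program_id in range(grid_size):  # on a GPU these run in parallel
        idx = program_id  # what tl.program_id(0) returns inside the kernel
        indices.append(idx)
    return indices

# With a grid of (1,), as in this kernel, idx is always 0:
print(simulate_launch(1))  # [0]
# With a larger grid, each instance gets a distinct index:
print(simulate_launch(4))  # [0, 1, 2, 3]
```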

a = tl.load(a_ptr + idx)

This loads the single value pointed to by a_ptr into a register for computation. Since we are working with a single element, idx will be 0, so a_ptr + idx is equivalent to a_ptr.

b = tl.load(b_ptr + idx)

Similarly, this loads the single value pointed to by b_ptr into a register.

tl.store(c_ptr + idx, a + b)

This computes the sum of a and b and stores the result into the memory location pointed to by c_ptr. Again, idx is 0, so c_ptr + idx is equivalent to c_ptr.
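The whole load/add/store sequence can be modeled on the CPU with plain Python, using lists to stand in for GPU memory and list indices for pointer offsets. This is only an illustration of the semantics, not how Triton executes the kernel:

```python
# A CPU model (plain Python, not Triton) of the load/add/store sequence.
# Lists stand in for GPU memory; list indices stand in for pointer offsets.

a_mem = [2.0]  # memory behind a_ptr
b_mem = [3.0]  # memory behind b_ptr
c_mem = [0.0]  # memory behind c_ptr

idx = 0             # tl.program_id(0) for the single program instance
a = a_mem[idx]      # tl.load(a_ptr + idx)
b = b_mem[idx]      # tl.load(b_ptr + idx)
c_mem[idx] = a + b  # tl.store(c_ptr + idx, a + b)

print(c_mem[0])  # 5.0
```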

grid_size = (1,)  # Launch a single program instance

This specifies the grid for the kernel launch. The grid defines how many program instances will execute the kernel. Here, (1,) means only one program instance is launched, so the whole operation is handled by a single instance.
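In real kernels that process N elements in blocks of BLOCK_SIZE, the grid size is typically computed by ceiling division (Triton provides triton.cdiv for this). A minimal sketch in plain Python, with N and BLOCK_SIZE chosen purely for illustration:

```python
import math

# Hypothetical sizes for illustration: N elements, BLOCK_SIZE per instance.
N = 1000
BLOCK_SIZE = 256

# triton.cdiv(N, BLOCK_SIZE) is the usual helper; it is just ceiling division,
# so the last (partial) block still gets a program instance.
grid_size = (math.ceil(N / BLOCK_SIZE),)
print(grid_size)  # (4,)
```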

add_kernel[grid_size](a, b, c)

This launches the kernel. The grid_size subscript ensures only one program instance runs, and the tensors a, b, and c are passed as arguments; Triton automatically converts torch tensors into pointers to their underlying GPU memory.


Complete Kernel Code

Below is the complete kernel along with the host-side code to call it:

import triton
import triton.language as tl

@triton.jit
def add_kernel(a_ptr, b_ptr, c_ptr):
    idx = tl.program_id(0)  # Get the program instance index
    a = tl.load(a_ptr + idx)  # Load single input value A
    b = tl.load(b_ptr + idx)  # Load single input value B
    tl.store(c_ptr + idx, a + b)  # Perform addition and store the result

# Host-side code
import torch

a = torch.tensor([2.0], device="cuda")  # Input value A
b = torch.tensor([3.0], device="cuda")  # Input value B
c = torch.empty_like(a)                  # Output value

grid_size = (1,)  # Launch a single program instance
add_kernel[grid_size](a, b, c)

print(f"Result: {c[0].item()}")

Congrats, you are now an expert Triton programmer!! Just kidding 😛