Simplest Triton Kernel
This is the simplest Triton kernel you can write: a kernel that adds two single elements. Let's break it down line by line to understand how it works.
Kernel Code Breakdown
@triton.jit
This decorator marks the function as a Triton kernel. Triton JIT-compiles the function into GPU machine code the first time it is launched and caches the result for later calls.
def add_kernel(a_ptr, b_ptr, c_ptr):
This defines the kernel function. It takes three arguments:
a_ptr: a pointer to a single input value.
b_ptr: a pointer to another single input value.
c_ptr: a pointer to the memory location where the result will be stored.
idx = tl.program_id(0)
tl.program_id(0) fetches the unique index of the current program instance along axis 0 of the launch grid. (Triton calls each launched instance a "program"; it plays a role similar to a thread block in CUDA.) For this simplest kernel, only one program is launched, so idx will always be 0.
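To see what program_id actually buys you, here is a hedged sketch (the kernel name add_each_kernel is my own, and the GPU launch assumes a CUDA-capable device) where a grid of four programs runs the same body, each with a different idx, so each program adds one element of a length-4 vector:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_each_kernel(a_ptr, b_ptr, c_ptr):
    idx = tl.program_id(0)              # each program sees a different idx: 0..3
    a = tl.load(a_ptr + idx)
    b = tl.load(b_ptr + idx)
    tl.store(c_ptr + idx, a + b)

# Plain-Python picture of what the four programs compute together:
a_host = [1.0, 2.0, 3.0, 4.0]
b_host = [10.0, 20.0, 30.0, 40.0]
c_host = [a_host[i] + b_host[i] for i in range(4)]   # one i per program

if torch.cuda.is_available():
    a = torch.tensor(a_host, device="cuda")
    b = torch.tensor(b_host, device="cuda")
    c = torch.empty_like(a)
    add_each_kernel[(4,)](a, b, c)      # grid of 4: idx takes the values 0..3
```

The key point is that the kernel body never loops over elements; parallelism comes entirely from launching more program instances.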
a = tl.load(a_ptr + idx)
This loads the single value pointed to by a_ptr into a register for computation. Since we are working with a single element, idx is 0, so a_ptr + idx is equivalent to a_ptr.
b = tl.load(b_ptr + idx)
Similarly, this loads the single value pointed to by b_ptr into a register.
tl.store(c_ptr + idx, a + b)
This computes the sum of a and b and stores the result into the memory location pointed to by c_ptr. Again, idx is 0, so c_ptr + idx is equivalent to c_ptr.
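One thing this single-element kernel never needs is bounds checking, but it becomes essential as soon as a program touches more than one element: tl.load and tl.store accept a mask argument, and lanes whose mask is False are skipped. A hedged sketch of the pattern (the names masked_add_kernel, BLOCK, and n are my own):

```python
import triton
import triton.language as tl

@triton.jit
def masked_add_kernel(a_ptr, b_ptr, c_ptr, n, BLOCK: tl.constexpr):
    idx = tl.program_id(0)
    offsets = idx * BLOCK + tl.arange(0, BLOCK)  # BLOCK consecutive indices
    mask = offsets < n                           # False for lanes past the end
    a = tl.load(a_ptr + offsets, mask=mask)      # out-of-bounds lanes load nothing
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(c_ptr + offsets, a + b, mask=mask)  # ...and store nothing

# The mask in plain Python: with n = 5 and BLOCK = 8, program 0 has
# eight lanes but only the first five are active.
n, BLOCK = 5, 8
mask = [off < n for off in range(BLOCK)]
```

Without the mask, the last program would read and write past the end of the tensors whenever n is not a multiple of BLOCK.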
grid_size = (1,) # A single program
This specifies the grid for the kernel launch. The grid defines how many program instances execute the kernel. Here, (1,) means only one instance is launched, so the entire operation runs in a single program.
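In practice the grid is rarely hard-coded; it is computed from the data size with a ceiling division so that the last, possibly partial, block of elements still gets a program. A host-side sketch (the sizes here are illustrative):

```python
def cdiv(n: int, block: int) -> int:
    """Ceiling division: how many blocks of size `block` cover `n` elements."""
    return (n + block - 1) // block

n, BLOCK = 1000, 256
grid_size = (cdiv(n, BLOCK),)  # (4,): programs 0..2 are full, program 3 is partial
```

Triton ships this helper as triton.cdiv, so real launch code typically reads grid = (triton.cdiv(n, BLOCK),).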
add_kernel[grid_size](a, b, c)
This launches the kernel. The grid_size ensures only one program instance runs, and the arguments a, b, and c are PyTorch tensors; Triton automatically passes pointers to their underlying GPU memory into the kernel.
Complete Kernel Code
Below is the complete kernel along with the host-side code to call it:
import triton
import triton.language as tl
@triton.jit
def add_kernel(a_ptr, b_ptr, c_ptr):
idx = tl.program_id(0) # Get the program index
a = tl.load(a_ptr + idx) # Load single input value A
b = tl.load(b_ptr + idx) # Load single input value B
tl.store(c_ptr + idx, a + b) # Perform addition and store the result
# Host-side code
import torch
a = torch.tensor([2.0], device="cuda") # Input value A
b = torch.tensor([3.0], device="cuda") # Input value B
c = torch.empty_like(a) # Output value
grid_size = (1,) # A single program
add_kernel[grid_size](a, b, c)
print(f"Result: {c[0].item()}")
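As a natural next step, here is a hedged sketch of a vectorized version that adds whole tensors instead of single elements (it assumes a CUDA-capable GPU, and the names vector_add_kernel and BLOCK_SIZE are my own): each program handles a block of 1024 elements, the grid is sized to cover the array, and a mask guards the tail.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n                       # ignore lanes past the end
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(c_ptr + offsets, a + b, mask=mask)

# CPU picture of the same computation on a tiny example:
ch = [x + y for x, y in zip([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])]

if torch.cuda.is_available():
    n = 1000
    a = torch.rand(n, device="cuda")
    b = torch.rand(n, device="cuda")
    c = torch.empty_like(a)
    grid = (triton.cdiv(n, 1024),)           # enough programs to cover n elements
    vector_add_kernel[grid](a, b, c, n, BLOCK_SIZE=1024)
    assert torch.allclose(c, a + b)
```

This is essentially the classic Triton vector-add pattern: the single-element kernel above is the same code with BLOCK_SIZE shrunk to one element and the mask removed.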
Congrats, you are now an expert Triton programmer!! Just kidding 😛