Softmax Kernel
This might not seem like a beginner-friendly kernel, but trust me, it's much simpler than you think. Softmax is your gateway to understanding machine learning kernels.
It converts a vector of values into probabilities, ensuring they sum to 1. Let's dive in and implement it efficiently using Triton!
Mathematically, softmax is defined as:
Softmax(x_i) = exp(x_i) / sum_j exp(x_j)
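Before touching the GPU, here is a minimal NumPy sketch of that formula for a single vector (NumPy is used only for illustration; the Triton kernel below does not depend on it):

import numpy as np

def softmax(x):
    # Direct translation of the formula above (no stability trick yet).
    e = np.exp(x)
    return e / e.sum()

print(softmax(np.array([1.0, 2.0, 3.0])))  # [0.09003057 0.24472847 0.66524096]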
Kernel Code Breakdown
offsets = row_idx * n_cols + tl.arange(0, n_cols)  # Addresses of this row's elements
mask = tl.arange(0, n_cols) < n_cols                # Always true here; matters only if the block is padded past n_cols
row = tl.load(input_ptr + offsets, mask=mask)       # Load the row
row_max = tl.max(row, axis=0)                       # Row maximum (for numerical stability)
row_exp = tl.exp(row - row_max)                     # Exponentiate the shifted values
row_sum = tl.sum(row_exp, axis=0)                   # Normalization denominator
tl.store(output_ptr + offsets, row_exp / row_sum)   # Normalize and store the result
Softmax Computation, Step by Step
- Why compute row_max first? Exponentiation can overflow for large inputs, which makes the naive formula numerically unstable. Subtracting row_max keeps the exponents in a stable range without changing the result (see the short demo after this list).
- Exponentiation & summation: tl.exp(row - row_max) exponentiates the shifted values; tl.sum(row_exp, axis=0) computes the denominator used for normalization.
- Final softmax calculation: each element is divided by the row sum, so every row of the output is a valid probability distribution that sums to 1.
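To see why the shift matters, here is a small standalone demo (plain PyTorch, separate from the kernel) comparing the naive formula with the max-shifted one:

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows float32, so the result is all NaN.
naive = torch.exp(x) / torch.exp(x).sum()
print(naive)   # tensor([nan, nan, nan])

# Shifted softmax: mathematically identical, numerically well behaved.
shifted = x - x.max()
stable = torch.exp(shifted) / torch.exp(shifted).sum()
print(stable)  # tensor([0.0900, 0.2447, 0.6652])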
Complete Kernel Code
Below is the complete kernel along with the host-side code to call it:
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(input_ptr, output_ptr, n_rows: tl.constexpr, n_cols: tl.constexpr):
    row_idx = tl.program_id(0)  # One program instance per row.
    if row_idx < n_rows:  # Guard against out-of-range programs (always true with this grid).
        col_offsets = tl.arange(0, n_cols)  # n_cols must be a power of two for tl.arange.
        offsets = row_idx * n_cols + col_offsets
        mask = col_offsets < n_cols  # Always true here; needed only if the block is padded past n_cols.
        row = tl.load(input_ptr + offsets, mask=mask)       # Load one row
        row_max = tl.max(row, axis=0)                       # Row maximum for numerical stability
        row_exp = tl.exp(row - row_max)                     # Exponentiate the shifted values
        row_sum = tl.sum(row_exp, axis=0)                   # Normalization denominator
        tl.store(output_ptr + offsets, row_exp / row_sum, mask=mask)  # Store the normalized row


# Host-side code
import torch

n_rows, n_cols = 4, 4  # Dimensions of the matrix
input_matrix = torch.rand(n_rows, n_cols, device="cuda", dtype=torch.float32)  # Input matrix
output_matrix = torch.empty_like(input_matrix)  # Output matrix

grid_size = (n_rows,)  # One program instance per row
softmax_kernel[grid_size](input_matrix, output_matrix, n_rows=n_rows, n_cols=n_cols)
print(f"Result:\n{output_matrix}")
🎯 Now you know how to do all the simple stuff, let's jump to the next level! 🚀