Softmax Kernel

This might not seem like a beginner-friendly kernel, but trust me, it's much simpler than you think. Softmax is your gateway to understanding machine learning kernels.

Softmax converts a vector of raw scores into a probability distribution: every output lies between 0 and 1, and the outputs sum to 1. Let's dive in and implement it efficiently in Triton!

The Softmax Formula

softmax(x_i) = exp(x_i) / sum_j exp(x_j)

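To ground the formula before writing any GPU code, here is a minimal NumPy sketch of the same computation (the name softmax_reference is purely illustrative; this naive version skips the stabilization trick discussed below):

import numpy as np

def softmax_reference(x):
    # Direct translation of the formula; see the stability note below
    # before using this on large inputs.
    exps = np.exp(x)
    return exps / np.sum(exps)

print(softmax_reference(np.array([1.0, 2.0, 3.0])))  # [0.09003057 0.24472847 0.66524096], sums to 1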

Kernel Code Breakdown

col_offsets = tl.arange(0, BLOCK_SIZE)
offsets = row_idx * n_cols + col_offsets
mask = col_offsets < n_cols
row = tl.load(input_ptr + offsets, mask=mask, other=-float('inf'))
row_max = tl.max(row, axis=0)
row_exp = tl.exp(row - row_max)
row_sum = tl.sum(row_exp, axis=0)
tl.store(output_ptr + offsets, row_exp / row_sum, mask=mask)

Softmax Computation, Step by Step

  • Why Compute row_max First?
    • exp overflows for large inputs (in float32, anything above roughly 88 becomes inf). Subtracting row_max caps the largest exponent at exp(0) = 1, and because the shift cancels in the numerator and denominator, the result is mathematically unchanged. A concrete demonstration follows this list.
  • Why BLOCK_SIZE and a Mask?
    • tl.arange requires a power-of-two, compile-time bound, so the kernel rounds n_cols up to BLOCK_SIZE and masks off the extra lanes. Masked lanes load -inf, which exponentiates to 0 and drops out of the sum.
  • Exponentiation & Summation:
    • tl.exp(row - row_max): exponentiates the stabilized values.
    • tl.sum(row_exp, axis=0): computes the denominator for normalization.
  • Final Softmax Calculation:
    • Each element is divided by the row sum, yielding probabilities that sum to 1.
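To see the instability concretely, here is a small CPU-side sketch (plain NumPy, separate from the Triton kernel):

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows to inf, and inf / inf yields nan.
naive = np.exp(x) / np.sum(np.exp(x))
print(naive)  # [nan nan nan], with an overflow warning

# Stable softmax: the largest exponent becomes exp(0) = 1, so nothing overflows.
exps = np.exp(x - np.max(x))
print(exps / np.sum(exps))  # [0.09003057 0.24472847 0.66524096]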

Complete Kernel Code

Below is the complete kernel along with the host-side code to call it:

import triton
import triton.language as tl

@triton.jit
def softmax_kernel(input_ptr, output_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)  # One program instance per row; the grid launches exactly n_rows of them.
    col_offsets = tl.arange(0, BLOCK_SIZE)  # BLOCK_SIZE: power of 2 >= n_cols (tl.arange requires it).
    offsets = row_idx * n_cols + col_offsets
    mask = col_offsets < n_cols  # Guard the lanes beyond the row's true width.
    row = tl.load(input_ptr + offsets, mask=mask, other=-float('inf'))  # Load row; padding reads -inf.
    row_max = tl.max(row, axis=0)  # Row max for numerical stability.
    row_exp = tl.exp(row - row_max)  # exp(-inf) = 0, so padded lanes drop out of the sum.
    row_sum = tl.sum(row_exp, axis=0)  # Denominator for normalization.
    tl.store(output_ptr + offsets, row_exp / row_sum, mask=mask)  # Normalize and store the row.

# Host-side code
import torch

n_rows, n_cols = 4, 4  # Dimensions of the matrix
input_matrix = torch.rand(n_rows, n_cols, device="cuda", dtype=torch.float32)  # Input matrix
output_matrix = torch.empty_like(input_matrix)  # Output matrix

BLOCK_SIZE = triton.next_power_of_2(n_cols)  # Round n_cols up to a power-of-2 bound for tl.arange
grid_size = (n_rows,)  # One program per row
softmax_kernel[grid_size](input_matrix, output_matrix, n_cols, BLOCK_SIZE=BLOCK_SIZE)

print(f"Result:\n{output_matrix}")

🎯 Now that you've got the simple stuff down, let's jump to the next level! 🚀