
写 CUDA Kernel 写了三年最怕的是什么不是算法难是调grid, block那一行永远写不对。线程索引算错一位debug 一天。Shared Memory bank conflict 搞不明白性能掉一半。等到好不容易跑通了换个 GPU 架构又得重来一遍。后来同事说“试试 Triton 吧一行triton.jit搞定。”我当时是不信的。直到我用 Triton 写完第一个向量加法 Kernel对比 CUDA 版本的代码量直接腰斩——而且性能居然不输手写 CUDA。这篇文章就来复盘 Triton 程序的完整编写流程从 API 到实战把刚上手时最容易踩的坑都给你趟一遍。一、Triton 到底是什么来头Triton 是 OpenAI 搞的开源 GPU 编程语言定位很明确比 CUDA 好写比 PyTorch 灵活。传统 CUDA 编程里你得手动管线程thread、线程束warp、线程块block每个线程该算哪段数据索引写错了就是灾难。Triton 直接把这个模型翻了个个——你写代码时假装只处理一块数据tile编译器负责把这块数据自动拆给几百个线程去并行执行。这叫tile-based programming model基于块的编程模型。核心思路一句话你关心数据块Triton 关心线程。整个编译流程是这样走的triton.jit装饰器捕获你函数的AST抽象语法树不是直接跑 Python 代码AST 被转成 Triton IRMLIR 的自定义方言里面全是 tile 级别的操作Triton IR 进一步降到 TritonGPU IR决定每个 warp 分多少数据、寄存器怎么布局最后走 LLVM 生成 PTXNVIDIA 驱动再转成 SASS 机器码这一套流程最爽的地方同一份 Triton 代码换 GPU 不用改。因为布局优化、warp 分配这些脏活全在编译器里自动完成。二、核心 API 速览这五个东西必须先认全下面这张表是 Triton 编程的身份证不记牢后面写代码会不停翻文档。2.1triton.jit— 一切的起点triton.jitdefmy_kernel(x_ptr,y_ptr,output_ptr,n_elements,BLOCK_SIZE:tl.constexpr):...关键点这不是普通的 Python 装饰器——它不会执行你写的代码而是把函数体抓成 AST 丢给编译器。带tl.constexpr标记的参数是编译期常量。同一个 kernel 用不同BLOCK_SIZE调用编译器会分别生成两份优化过的机器码——这叫特化specialization是 Triton 性能不输 CUDA 的核心原因之一。Kernel 函数里不能随便写 Python只能用tl.load/tl.store/tl.arange这些 Triton DSL 操作。2.2triton.autotune— 让 GPU 自己挑参数triton.autotune(configs[triton.Config(kwargs{BLOCK_SIZE:128},num_warps4),triton.Config(kwargs{BLOCK_SIZE:1024},num_warps8),],key[x_size])triton.jitdefkernel(x_ptr,x_size,**META):BLOCK_SIZEMETA[BLOCK_SIZE]这是 Triton 最让我惊艳的功能。你不用猜BLOCK_SIZE设 128 还是 1024 性能好——把候选配置丢进去Triton 会逐个编译并跑一遍自动选出最优的那个。几个要注意的坑key参数是用来分组缓存的。如果key[x_size]当x_size变化时才会重新评估所有配置。设计 key 的时候只放会影响性能选择的参数别把什么都塞进去否则 autotune 开销爆炸。autotune 会把 kernel 跑很多遍如果你 kernel 里会修改全局状态比如累加计数必须用reset_to_zero参数指定哪些 tensor 每次跑前归零。第一次调用时 autotune 有预热开销后面命中缓存就快了。2.3triton.Config— 四个参数决定生死一个 Config 对象就是一份内核配置方案autotune 会逐个尝试。四个核心参数参数含义调优建议num_warps每个 block 分配的 warp 数1 warp 32 线程VI00 用 2-4A100 用 4-8H100 可上 8-16num_stages异步数据预取的流水线深度计算密集型 2-3访存密集型如 MatMul3-5num_ctasblock cluster 中的 block 数SM90 专属H100 才需要关注maxnreg单线程最大寄存器数寄存器溢出时调这个不是所有平台都支持最重要的交互num_warps和num_stages会抢同一块共享内存shared memory。warps 越多 → 线程越多 → 每个线程分到的寄存器越少 → 可能触发寄存器溢出register spilling。stages 越多 → 预取缓存越大 → 占的 shared memory 越多。加一个就得考虑减另一个别两个一起拉满。2.4 Math Ops — 这些算子直接能用算子说明tl.abs(x)逐元素绝对值tl.cdiv(x, div)向上取整除法算 grid 大小必用tl.sqrt(x)快速平方根硬件近似比math.sqrt快但精度略低tl.softmax(x)Softmax注意是整块计算不要自己手写tl.cos(x)/tl.sin(x)三角函数cdiv是最常用的——因为你要根据n_elements和BLOCK_SIZE算出需要多少个 block公式就是triton.cdiv(n_elements, BLOCK_SIZE)。2.5 Debug Ops — GPU 上的 printfCUDA 调试痛苦的原因之一kernel 里打不了断点只能靠printf。Triton 把 debug 分了两层算子阶段用途tl.static_print(...)编译期打印编译时常量如BLOCK_SIZEtl.static_assert(cond)编译期编译时断言如检查BLOCK_SIZE是 2 的幂tl.device_print(...)运行期GPU 上实时打印变量值tl.device_assert(cond)运行期运行时断言如检查mask范围static_print和static_assert非常实用——它们不会产生任何 GPU 指令只在 JIT 编译时执行零性能开销。三、实战用 Triton 写向量加法光看 API 没用直接上代码。3.1 Kernel 函数triton.jitdefadd_kernel(x_ptr,y_ptr,output_ptr,n_elements,BLOCK_SIZE:tl.constexpr):# Step 1: 我是第几个 blockpidtl.program_id(axis0)# Step 2: 这个 block 负责的数据起始位置block_startpid*BLOCK_SIZE# Step 3: 生成这个 block 里的所有偏移量 [0, 1, 2, ..., BLOCK_SIZE-1]offsetsblock_starttl.arange(0,BLOCK_SIZE)# Step 4: 最后一个 block 可能越界做 maskmaskoffsetsn_elements# Step 5: 从全局内存加载xtl.load(x_ptroffsets,maskmask)ytl.load(y_ptroffsets,maskmask)# Step 6: 算outputxy# Step 7: 写回全局内存tl.store(output_ptroffsets,output,maskmask)这里解释几个新人容易懵的点tl.program_id(axis0)Triton 里没有blockIdx.x这种 CUDA 概念直接用program_id获取我这个 block 是第几个。axis0 就是一维 gridaxis1 / axis2 对应二维 / 三维。tl.arange(0, BLOCK_SIZE)生成一个从 0 到 BLOCK_SIZE-1 的向量。注意这不是 Python 的 range而是一个 GPU 上的向量后续所有操作都是按这个向量并行展开的。maskoffsets n_elements数据总长度不一定是 BLOCK_SIZE 的整数倍最后一个 block 会多算一些位置。mask 确保这些越界的偏移量不会被真的读写——tl.load和tl.store遇到 maskFalse 的位置会直接跳过。指针运算x_ptr offsetsTriton 里指针是整型直接加偏移量就行不需要x_ptr[offsets]这种语法。3.2 封装调用函数defadd(x:torch.Tensor,y:torch.Tensor):# 分配输出 tensoroutputtorch.empty_like(x)# 安全检查数据必须在 GPU 上assertx.is_cudaandy.is_cudaandoutput.is_cuda n_elementsoutput.numel()# 计算 grid需要多少个 blockgridlambdameta:(triton.cdiv(n_elements,meta[BLOCK_SIZE]),)# 启动 kerneladd_kernel[grid](x,y,output,n_elements,BLOCK_SIZE1024)returnoutput最需要解释的是grid lambda meta: ...这个写法meta是一个字典包含BLOCK_SIZE等编译期常量。这里meta[BLOCK_SIZE]就是 1024。返回值是一个元组(grid_x, grid_y, grid_z)这里只有一维所以是单元素元组。add_kernel[grid]这种调用语法类似 CUDA 的grid, block只不过 Triton 的block 大小已经在BLOCK_SIZE: tl.constexpr里定义好了这里只指定 grid。3.3 运行结果$ python 01-vector-add.py输出显示 Triton 计算结果与 PyTorch 原生算子的最大差异为0.0——完全一致。性能对比那块更有意思从 4096 个元素一路测到 1.34 亿个元素Triton 版本和 PyTorch底层也是 CUDA的耗时几乎完全重叠差距在 1% 以内。这说明用 Triton 写的向量加法编译出来的机器码质量不输 PyTorch 高度优化的 CUDA kernel。四、踩坑记录我在 Triton 上栽过的跟头写几个自己实际遇到、PPT 里不会直接说的坑坑 1BLOCK_SIZE不是越大越好直觉上 block 越大并行度越高但 block 太大会导致① 寄存器不够用触发 spilling性能反而暴跌② shared memory 不够用如果你的 kernel 用了。向量加法这种极简单 kernel1024 是个不错的默认值复杂 kernel 如矩阵乘法每维 64-128 更常见。坑 2mask 没写对静默出 bugtl.load的mask参数如果不传越界的地址会读到未定义值——GPU 上不会直接 crash但算出来的结果可能完全对也可能偶尔错特别难排查。任何带offsets的load/store都要检查边界。坑 3autotune 第一次跑很慢autotune 会逐配置编译运行候选配置多的话第一次调用可能要等几十秒甚至几分钟。这正常因为 Triton 在 JIT 编译。第二次调用命中缓存就秒开了。生产环境建议提前 warmup。坑 4num_stages不是越大越好num_stages增加异步预取的流水线深度能隐藏访存延迟但每多一级 stage 就多占一块 shared memory。如果你的 kernel 本身 shared memory 用量就高比如矩阵乘法里的大块 tile再加 stages 会爆 shared memory 容量编译直接失败。五、小结用 Triton 写 GPU 程序的体验打个不恰当的比方CUDA 像手动挡每个换挡时机都得自己把握Triton 像自动挡 运动模式把最烦的线程调度交给编译器但关键参数BLOCK_SIZE、num_warps、num_stages你仍然能调。回到开头那个向量加法——从 CUDA 迁移到 Triton代码量减半性能持平而且换个 GPU 不用改一行代码。对于大部分我需要一个自定义 kernel但不想为线程索引掉头发的场景Triton 是目前最好的选择。本文基于杜玉博老师《Triton程序编写》PPT 整理图片均为原 PPT 截图。代码示例可在 Triton 官方仓库 找到完整教程。