使用 C++ 封装的算子获得更好的性能#

使用 FlagGems 时的另一条高级的优化路径是针对所选的操作使用其中的C++ 封装的算子。 尽管 Triton 内核通常能够给出相当不错的计算性能,Triton 本身是使用 Python 实现的 DSL。 这意味着算子的定义以及算子的运行时派发机制都是用 Python 编写的, 因此在延迟非常敏感或者对吞吐要求极为苛刻的场景下会存在不可忽视的性能开销

为了解决这一问题,FlagGems 提供了一套 C++ 运行时解决方案,用 C++ 语言来实现 算子的封装逻辑、注册机制和运行时管理,与此同时仍然复用下层的 Triton 内核来完成实际计算。 这种方法能够保留 Triton 中内核级别的效率,同时大幅降低 Python 语言相关的性能开销, 使得用户能够与底层的 CUDA 工作流进行更为紧密的集成,提升整体的推理性能。

1. 安装#

要使用 C++ 算子封装:

  1. 遵从安装指南中的指令编译、安装带有 C++ 扩展特性的 flag_gems 包。

  2. 使用下面的代码段来验证安装是否成功:

    try:
        from flag_gems import c_operators
        has_c_extension = True
    except Exception as e:
        c_operators = None  # 避免在 c_operators 不可用时出现 import 错误
        has_c_extension = False

    如果 has_c_extensionTrue,则 C++ 运行时所支持的执行路径是可用的。

  3. 安装成功之后,C++ 封装的算子在补丁模式下具有更高的优先级。 当显式使用 FlagGems 所提供的模块来构建模型时, C++ 封装的算子也比其对应的 Python 等价实现的优先级更高。

    例如,算子 gems_rms_forward 默认会使用 C++ 封装版本的 rms_norm。 你可以参考源码库中的 normalization.py 文件,更好地了解如何集成 C++ 封装的算子以及如何调用它们。

2. 显式调用 C++ 算子#

如果你希望直接调用 C++ 封装的算子,略过打补丁逻辑或者其他回退执行路径, 可以按下面的代码所给的那样,使用 torch.ops.flag_gems 名字空间:

output = torch.ops.flag_gems.fused_add_rms_norm(...)

这种方式能够让你对算子的派发进行更为精确的控制,在某些性能很关键的语境中可能很有用。

参考:目前支持的 C++ 封装的算子#

算子名称描述
add逐元素的加法
bmm批量的矩阵乘法
cat串接
fused_add_rms_norm加法与 RMSNorm 的融合
mm矩阵乘法
nonzero返回非零元素的索引
rms_norm均方根归一化
rotary_embedding旋转位置编码
sum跨维度的降维(规约)

作为持续性能优化工作的一部分,我们一直在努力扩大这一列表。