在非英伟达(NVIDIA)硬件上使用 FlagGems#

1. 支持的平台#

FlagGems 在 NVIDIA 芯片之外支持若干不同类型的 AI 芯片或平台。 请参阅平台支持 文档了解已经验证过的平台的最新列表。

2. 统一的使用接口#

无论下层使用的是哪种硬件,flag_gems 的用法始终保持不变。 从 NVIDIA 平台切换到非 NVIDIA 平台时,一般而言不需要更改应用代码。

当你在代码中导入了 flag_gems 包,并且 启用了 FlagGems 加速之后, 算子的派发机制会自动将对算子的调用指向针对当前后端的正确实现上。 这一派发机制为开发者提供了跨不同环境的体验一致性。

3. 平台需求#

尽管 FlagGems 的使用模式保持不变,在非 NVIDIA 平台上运行 FlagGems 时仍然需要满足一些前置条件。你必须在目标平台上安装了 PyTorchTriton 编译器。

通常而言,有两种方式可以实现兼容的软件构建:

  1. 咨询你的硬件厂商

    某些平台需要一些额外的安装配置或者补丁操作。 硬件厂商通常会针对自己的芯片开发维护 PyTorchTriton 的定制构建版本。 你需要与厂商取得联系,以获得这些软件包的正确版本。

  1. 尝试 FlagTree 项目

    FlagTree 项目提供一种统一的 Triton 编译器, 能够支持多种不同的 AI 芯片,包括 NVIDIA 和非 NVIDIA 平台。 FlagTree 能够将特定于厂商的补丁和增强聚合在一起,形成一个共享的、开源的编译器后端, 从而简化编译器的维护工作,保证跨多个平台的兼容性。

    需要注意的是 FlagTree 仅提供在 Triton 语言层的编译器框架。 你仍然需要安装部署一个合适的 PyTorch 发行版本。

4. 后端自动检测与手动设置#

默认情况下,FlagGems 能够在运行时自动检测当前使用的硬件后端,为之选择对应的算子实现。 很多时候,所有的组件都能够直接工作,不需要手动的配置。

如果内置的自动后端检测机制失败,或者在你的环境中出现了兼容性相关的问题, 你可以手动设置目标后端,以确保运行时算子的行为是正确的。 你可以通过设置下面的环境变量来指定后端,之后再运行你的代码:

export GEMS_VENDOR=<厂商名称>

参阅平台支持文档, 了解不同厂商对应的符号名。

警告

手动指定的后端名称要与实际的硬件平台匹配。 手动设置一个错误的后端符号名可能会导致软件运行错误。

你可以通过在运行时查看 flag_gems.vendor_name 的取值来检查当前使用的后端。

import flag_gems
print(flag_gems.vendor_name)