开始使用 FlagGems#
1. 安装 FlagGems#
FlagGems 可以以纯 Python 包的形式安装,也可以附带 C++ 扩展特性来安装, 以实现更好的运行时性能。你可以参阅安装指南 的说明,了解不同安装方式、选项的详细指令。
2. 检查安装状态#
在安装了 flag_gems 包及其所依赖的其他软件包之后,
你可能希望检查这些软件包是否能够正常工作。
2.1 检查 PyTorch 环境#
你要执行的第一项检查是确认能够在你的工作环境中导入 torch 软件包:
python -c "import torch; print(torch.__version__)"如果你所使用的是一个非 NVIDIA 的平台, 你可能需要安装一个由后端硬件厂商所提供的 PyTorch 插件。 你也可以针对这一插件执行类似的验证操作。 例如,在一个使用摩尔线程 GPU 的平台上,你可以使用下面的命令来检查 是否插件能够正常工作:
python -c "import torch_musa; print(torch_musa.__version__)"2.2 检查 Triton 安装#
接下来的一项检查是针对 triton 软件包的。
你可以使用下面的命令来检查 triton 包是否可以正常使用。
python -c "import torch, triton; print(triton.__version__)"如果你所使用的是一个非 NVIDIA 的平台, 你可能需要咨询的平台供应商,了解他们是否提供了一个定制版本。
警告
通常,厂商定制的
triton软件包与上游社区获得的软件包同名。 在进入下一步之前,你需要反复确认自己所使用的包是正确的版本。
2.3 检查 FlagGems 的安装#
针对 flag_gems 包,你也可以使用下面的命令执行类似的检查:
python -c "import flag_gems; print(flag_gems.__version__)"3. 开始使用 FlagGems#
你可以用多种不同方式来启用 flag_gems 所提供的加速算子。
下面的代码段在全局范围内启用 flag_gems:
import flag_gems
flag_gems.enable()
# 你自己的代码 ...你也可以针对自己代码中的某一部分,在特定的上下文中启用 flag_gems,
如下例所示:
# 你的代码 ...
# 在特定上下文中启用 flag_gems
with flag_gems.use_gems():
# 在上下文中使用加速的算子
# ...例如:
import torch
import flag_gems
M, N, K = 1024, 1024, 1024
A = torch.randn((M, K), dtype=torch.float16, device=flag_gems.device)
B = torch.randn((K, N), dtype=torch.float16, device=flag_gems.device)
with flag_gems.use_gems():
C = torch.mm(A, B)你可以查阅使用指南一节中的详细文档, 了解 FlagGems 的多种使用模式。