基本用法#
要使用 FlagGems 算子库中的算子,可以在运行你的程序之前导入 flag_gems 并启用加速。
你可以在全局启用 flag_gems,也可以针对特定的代码段启用 flag_gems。
除此之外,你还可以直接调用 flag_gems.ops 包中的指定算子。
1. 全局启用#
如果希望在你的整个程序中或者整个交互会话期间启用 FlagGems 算子,可以执行下面的语句:
import flag_gems
# 全局性地启用 FlagGems 算子
flag_gems.enable()一旦启用,你的代码中的所有被支持的算子都会自动替换为 FlagGems 中优化过的实现,
除此之外无需其他修改。这意味着 FlagGems 所支持的 torch.* / torch.nn.functional.*
调用都会被自动派发到加速版本的实现。例如:
import torch
import flag_gems
flag_gems.enable()
x = torch.randn(4096, 4096, device=flag_gems.device, dtype=torch.float16)
y = torch.mm(x, x)2. 指定作用域的启用#
在必要的时候,你可以使用 with... 语句仅针对指定的代码块启用 FlagGems:
import flag_gems
import torch
# 针对特定操作启用 flag_gems
with flag_gems.use_gems():
# 这段代码会使用 FlagGems 中被加速的算子
x = torch.randn(4096, 4096, device=flag_gems.device, dtype=torch.float16)
y = torch.mm(x, x)限定作用域的用法在以下场景比较有用:
- 对算子作性能基准测试,或者
- 比较不同实现之间的精度,或者
- 在复杂的工作流中有选择地应用加速算子
3. 直接调用#
你也可以略过 PyTorch 中的派发过程,直接调用 flag_gems.ops 包中的算子:
import torch
from flag_gems import ops
import flag_gems
a = torch.randn(1024, 1024, device=flag_gems.device, dtype=torch.float16)
b = torch.randn(1024, 1024, device=flag_gems.device, dtype=torch.float16)
c = ops.mm(a, b)4. 查询已注册的算子#
在启用了 FlagGems 之后,你可以检查系统中已经注册的算子列表:
import flag_gems
flag_gems.enable()
# 获取已注册的算子函数名
registered_funcs = flag_gems.all_registered_ops()
print("Registered functions:", registered_funcs)
# 获取已注册算子的主键
registered_keys = flag_gems.all_registered_keys()
print("Registered keys:", registered_keys)这一 API 有助于调试或者检查哪些算子在起作用。
5. 进阶用法#
你可以阅读以下相关文档,了解一些高级的使用场景: