Enabling FlagGems#

To use the operators from the FlagGems operator library, import flag_gems and enable acceleration before running your program. You can enable it globally or for a specific code block. You can also invoke operators directly from the flag_gems.ops package.

1. Global Enablement#

To apply FlagGems optimizations across your entire program or interactive session:

import flag_gems

# Enable FlagGems operators globally
flag_gems.enable()

Once enabled, all supported operators are replaced automatically by the optimized FlagGems implementations, with no further code changes: supported torch.* / torch.nn.functional.* calls are dispatched to FlagGems automatically. For example:

import torch
import flag_gems

flag_gems.enable()

x = torch.randn(4096, 4096, device=flag_gems.device, dtype=torch.float16)
y = torch.mm(x, x)

2. Scoped Enablement#

When needed, you can enable FlagGems only within a specific code block using a with statement:

import flag_gems
import torch

# Enable flag_gems for specific operations
with flag_gems.use_gems():
    # Code inside this block will use FlagGems-accelerated operators
    x = torch.randn(4096, 4096, device=flag_gems.device, dtype=torch.float16)
    y = torch.mm(x, x)

This scoped usage is useful when you want to:

  • run performance benchmarks,
  • compare correctness between implementations, or
  • apply acceleration selectively in complex workflows.

3. Direct Invocation#

You can bypass the PyTorch dispatch process and directly invoke operators from the flag_gems.ops package.

import torch
import flag_gems
from flag_gems import ops

a = torch.randn(1024, 1024, device=flag_gems.device, dtype=torch.float16)
b = torch.randn(1024, 1024, device=flag_gems.device, dtype=torch.float16)
c = ops.mm(a, b)

4. Query Registered Operators#

After enabling FlagGems, you can query which operators are registered:

import flag_gems

flag_gems.enable()

# Get list of registered function names
registered_funcs = flag_gems.all_registered_ops()
print("Registered functions:", registered_funcs)

# Get list of registered operator keys
registered_keys = flag_gems.all_registered_keys()
print("Registered keys:", registered_keys)

This is useful for debugging or verifying which operators are active.

5. Advanced Usage#

For advanced usage scenarios, check the following related documentation: