Get Started With FlagGems

1. Install FlagGems

FlagGems can be installed either as a pure Python package or as a package with C extensions for better runtime performance. See the installation guide for details on the available installation options.

2. Verify the installation

After installing the flag_gems package and its dependencies, verify that they work as expected.

2.1 Verify PyTorch environment

First, check that you can import torch in your working environment:

python -c "import torch; print(torch.__version__)"

If you are using a non-NVIDIA platform, you may have a PyTorch plugin provided by the backend vendor. You can run a similar check against this plugin. For example, on a MooreThreads GPU platform, you can verify that the plugin works using the following command:

python -c "import torch_musa; print(torch_musa.__version__)"

2.2 Verify the Triton setup

Next, verify the triton package. You can check that triton works as expected using the following command:

python -c "import torch, triton; print(triton.__version__)"

Note that if you are using a non-NVIDIA platform, you need to check with your platform vendor whether they provide a customized version of triton.

Warning

Usually, the vendor-customized version of triton has the same name as the upstream package. Double-check that you are using the correct package before proceeding.

2.3 Verify the FlagGems installation

You can do a similar verification for the flag_gems package, using the following command:

python -c "import flag_gems; print(flag_gems.__version__)"
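The individual checks in this section can also be combined into one script. The sketch below assumes only that each package exposes a __version__ attribute, as the commands above do. The helper name installed_versions is hypothetical.

```python
import importlib


def installed_versions(names):
    """Map each module name to its __version__, or None if the import fails."""
    versions = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None
            continue
        # Fall back to "unknown" for modules without a __version__ attribute.
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    # Add a vendor plugin such as "torch_musa" to this list if your platform uses one.
    for name, version in installed_versions(["torch", "triton", "flag_gems"]).items():
        print(f"{name}: {version if version is not None else 'NOT FOUND'}")
```

A NOT FOUND entry indicates that the package is missing from the active environment or that importing it failed.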

3. Start using FlagGems

You can enable the accelerated operators from flag_gems in several ways. The following code snippet enables flag_gems globally:

import flag_gems

flag_gems.enable()

# your code goes here

You can also enable flag_gems in a specific context for a certain section of your code, as shown below:

import flag_gems

# your code goes here ...

# enable flag_gems in a specific context
with flag_gems.use_gems():
    # use the accelerated operators here
    ...

For example:

import torch
import flag_gems

M, N, K = 1024, 1024, 1024
A = torch.randn((M, K), dtype=torch.float16, device=flag_gems.device)
B = torch.randn((K, N), dtype=torch.float16, device=flag_gems.device)
with flag_gems.use_gems():
    C = torch.mm(A, B)
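To see how closely the accelerated operator agrees with the native one, the example above can be extended to run the same multiplication inside and outside the context. This is a minimal sketch under stated assumptions: flag_gems.enable() has not been called globally, so torch.mm outside the context is assumed to use the native PyTorch kernel. The helper name mm_max_abs_diff is hypothetical, and the function returns None when torch or flag_gems is not installed.

```python
import importlib.util


def mm_max_abs_diff(m=256, n=256, k=256):
    """Return the largest absolute difference between torch.mm run with and
    without flag_gems, or None if torch or flag_gems is not installed."""
    if any(importlib.util.find_spec(pkg) is None for pkg in ("torch", "flag_gems")):
        return None
    import torch
    import flag_gems

    a = torch.randn((m, k), dtype=torch.float16, device=flag_gems.device)
    b = torch.randn((k, n), dtype=torch.float16, device=flag_gems.device)
    # Assumption: flag_gems is not enabled globally, so this is the native kernel.
    reference = torch.mm(a, b)
    with flag_gems.use_gems():
        accelerated = torch.mm(a, b)
    # Compare in float32 so the subtraction itself does not lose precision.
    return (accelerated.float() - reference.float()).abs().max().item()


if __name__ == "__main__":
    print(mm_max_abs_diff())
```

Small nonzero differences are expected with float16 inputs, since different kernels may accumulate the products in a different order.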

See the Using FlagGems section for more detailed documentation on the various usage patterns of FlagGems.