# Using C++-Based Operators for Optimal Performance
Another advanced optimization path with FlagGems is to use C++-wrapped operators for selected operations. Triton kernels offer reasonably good device-side compute performance, but Triton is a DSL embedded in Python. The operator wrappers that prepare each call and Triton's JIT runtime that specializes arguments, looks up compiled kernels and launches them both run in the Python interpreter. That host-side work is paid on every call, and it can become non-trivial overhead in latency-sensitive or high-throughput scenarios.
To address this, FlagGems provides a C++ runtime solution that encapsulates the operator's wrapper logic, registration mechanism, and runtime management in C++, while still reusing the underlying Triton kernels for the actual computation. This approach preserves the kernel-level efficiency from Triton while significantly reducing Python-related overhead, enabling tighter integration with low-level CUDA workflows and improving overall inference performance.
## 1. Architecture
The C++ wrapped operators in FlagGems are built on top of `libtriton_jit`, a multi-backend C++ runtime for Triton JIT functions. `libtriton_jit` reimplements the Triton JIT runtime in C++ (argument specialization, kernel caching, and launch) while delegating the actual compilation to the upstream Triton compiler.
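In the description above, "argument specialization" and "kernel caching" mean roughly the following. The Python below is a deliberately simplified model, not `libtriton_jit`'s code or API: `FakeTensor`, `spec_key` and `KernelCache` are hypothetical names, and the real Triton specialization rules are richer and differ between Triton versions. What it shows is the pattern: derive a cache key from properties of the arguments, compile on a miss, reuse on a hit. `libtriton_jit` performs this per-call work in C++ rather than in the Python interpreter.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple

Key = Tuple[Tuple[object, ...], ...]


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a device tensor: only dtype and address matter here."""
    dtype: str
    data_ptr: int


def spec_key(args: Tuple[object, ...]) -> Key:
    """Simplified specialization signature (not Triton's exact rules)."""
    key = []
    for a in args:
        if isinstance(a, FakeTensor):
            key.append(("ptr", a.dtype, a.data_ptr % 16 == 0))  # dtype + 16-byte alignment
        elif isinstance(a, bool):  # test bool before int: bool is a subclass of int
            key.append(("i1",))
        elif isinstance(a, int):
            key.append(("int", a % 16 == 0))  # divisibility hint, useful for wider memory accesses
        elif isinstance(a, float):
            key.append(("fp",))  # only the type matters, not the value
        else:
            key.append(("constexpr", a))  # compile-time constants are baked into the kernel
    return tuple(key)


class KernelCache:
    """Compile once per distinct signature, then reuse the compiled kernel."""

    def __init__(self, compile_fn: Callable[[Key], Callable[..., object]]) -> None:
        self._compile = compile_fn  # stands in for the (expensive) upstream Triton compiler
        self._kernels: Dict[Key, Callable[..., object]] = {}
        self.compiles = 0

    def launch(self, *args: object) -> object:
        key = spec_key(args)
        kernel = self._kernels.get(key)
        if kernel is None:  # miss: compile and remember the result
            kernel = self._compile(key)
            self._kernels[key] = kernel
            self.compiles += 1
        return kernel(*args)  # hit (or freshly compiled): launch
```

With this model, launching with `(FakeTensor("fp16", 1024), 128)` and then `(FakeTensor("fp16", 2048), 256)` compiles only once, because both argument lists map to the same key, whereas an integer such as `100` (not divisible by 16) or a different dtype produces a new key and a new compilation.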
In this stack:

- The Triton kernels (`*.py`) remain the source of truth for device-side computation. `libtriton_jit` handles JIT specialization, kernel caching, and backend-specific launches (currently supporting NVIDIA (CUDA), Moore Threads (MUSA), Huawei Ascend (NPU) and Iluvatar CoreX (IX)).
- FlagGems's C++ wrappers (under `lib/`, e.g. `rms_norm.cpp`, `mm.cpp`) implement tensor metadata handling, shape/type promotion, and argument preparation in C++, then invoke the Triton kernels through `libtriton_jit::TritonJITFunction`.
- On top of the wrappers, FlagGems ships two Python-facing extension modules (built from `src/flag_gems/csrc/cstub.cpp` and `src/flag_gems/csrc/aten_patch.cpp`) and one installable C++ library target (`FlagGems::operators`), which together expose the wrappers through four different invocation paths (see §3. Ways to invoke C++ operators).
Regardless of which invocation path is used, the wrapper logic itself is always executed in C++ (tensor metadata handling, argument type and specialization analysis, kernel cache lookup, and launch-argument preparation) instead of in Python — that part of the Python overhead is eliminated unconditionally, while the compute path continues to use the same Triton kernels.
Whether the PyTorch dispatcher overhead is also avoided depends on the path you pick:

- Paths that go through the dispatcher (`torch.ops.flag_gems.*` and the ATen direct replacement) still pay the usual dispatcher cost, but because the op implementation behind the dispatcher is C++ rather than a Python wrapper, the boxed-call overhead is still noticeably smaller than for a pure-Python custom op.
- Paths that bypass the dispatcher (the `c_operators` pybind module and the native C++ API) remove the dispatcher cost entirely; the native C++ API additionally removes any Python-interpreter involvement from the call path.
See §3. Ways to invoke C++ operators for the trade-offs of each path.
## 2. Install and enable
To make the C++ wrapper fully effective, you need both of the following.

**At build/install time:** enable the C++ extension and build in Release mode.

Install from source with at least `-DFLAGGEMS_BUILD_C_EXTENSIONS=ON` and `-DCMAKE_BUILD_TYPE=Release`. The latter ensures that both FlagGems itself and the `libtriton_jit` subproject built alongside it are compiled with platform-targeted optimizations; without it, the wrapper will be noticeably slower:

```shell
CMAKE_ARGS="-DFLAGGEMS_BUILD_C_EXTENSIONS=ON -DCMAKE_BUILD_TYPE=Release" \
  pip install -v -e .
```

If the command above fails, try adding `--no-build-isolation` so that pip reuses the PyTorch already installed in your environment and the build dependencies you installed from `requirements_<backend>.txt`.

Other useful options:

- `-DFLAGGEMS_BACKEND=<CUDA|IX|MUSA|NPU>`: select the target backend (default `CUDA`).
- `-DFLAGGEMS_BUILD_POINTWISE_DYNAMIC_CPP=ON`: build the pointwise-dynamic operators (`add`, `div`, `fill`).
- `-DFLAGGEMS_BUILD_CTESTS=ON`: build the `ctests/` GTest suite (the only way to verify the native C++ API in §3.4).
- `-DFLAGGEMS_USE_EXTERNAL_TRITON_JIT=ON -DTritonJIT_ROOT=<path>`: build against an externally installed `libtriton_jit`.
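If you script builds for several backends, the options above can be assembled programmatically. The helper below is a hypothetical convenience, not something FlagGems ships; it only uses the flags listed above and always includes the two required ones. It rejects whitespace in `triton_jit_root` because the result is a plain space-separated string.

```python
from typing import List, Optional

# Backend values accepted by -DFLAGGEMS_BACKEND, as listed above.
BACKENDS = ("CUDA", "IX", "MUSA", "NPU")


def flaggems_cmake_args(
    backend: str = "CUDA",
    pointwise_dynamic: bool = False,
    ctests: bool = False,
    triton_jit_root: Optional[str] = None,
) -> str:
    """Build a CMAKE_ARGS value from the documented options (hypothetical helper)."""
    if backend not in BACKENDS:
        raise ValueError(f"unknown backend {backend!r}; expected one of {BACKENDS}")
    args: List[str] = [
        "-DFLAGGEMS_BUILD_C_EXTENSIONS=ON",
        "-DCMAKE_BUILD_TYPE=Release",  # Release matters for wrapper performance
        f"-DFLAGGEMS_BACKEND={backend}",
    ]
    if pointwise_dynamic:
        args.append("-DFLAGGEMS_BUILD_POINTWISE_DYNAMIC_CPP=ON")  # add, div, fill
    if ctests:
        args.append("-DFLAGGEMS_BUILD_CTESTS=ON")  # needed to test the native C++ API
    if triton_jit_root is not None:
        if any(ch.isspace() for ch in triton_jit_root):
            # keep the space-separated CMAKE_ARGS string unambiguous
            raise ValueError("triton_jit_root must not contain whitespace")
        args += ["-DFLAGGEMS_USE_EXTERNAL_TRITON_JIT=ON", f"-DTritonJIT_ROOT={triton_jit_root}"]
    return " ".join(args)
```

For example, `print(f'CMAKE_ARGS="{flaggems_cmake_args("NPU", ctests=True)}" pip install -v -e .')` prints a ready-to-run install command for an Ascend build with the GTest suite enabled.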
See the install guide for the complete per-backend examples and `libtriton_jit` details.

**At runtime:** set `USE_C_EXTENSION=1`.
```shell
export USE_C_EXTENSION=1
```

Building the C++ extension alone is not enough. `src/flag_gems/config.py` gates several higher-level behaviors behind this environment variable; if you don't set it, the following paths silently fall back to Python:

| Path / behavior | Active after the build alone | Also requires `USE_C_EXTENSION=1` |
|---|---|---|
| §3.1 `torch.ops.flag_gems.*` | Yes | No |
| §3.3 `c_operators` pybind | Yes | No |
| §3.2 ATen direct replacement (`aten_patch`) | No | Yes |
| C++ branch in `flag_gems.enable()` | No | Yes |
| C++ branch in `GemsRMSNorm` and other `nn.Module`s | No | Yes |

So for normal use, keep `export USE_C_EXTENSION=1` in your environment.

### Quick sanity check
The following snippet verifies, in one go, all three paths that are observable from Python:
```python
import torch
import flag_gems
from flag_gems import c_operators, aten_patch
from flag_gems.config import has_c_extension, use_c_extension

assert has_c_extension, "C++ extension was not built"
assert use_c_extension, "please `export USE_C_EXTENSION=1`"
assert hasattr(torch.ops.flag_gems, "mm"), "§3.1 torch.ops.flag_gems.* not registered"
assert aten_patch.get_registered_ops(), "§3.2 no ATen op has been replaced"
_ = c_operators.mm  # §3.3
```

The §3.4 native C++ API is not observable from Python. To verify it, build with `-DFLAGGEMS_BUILD_CTESTS=ON` and run `ctest`:

```shell
BUILD_DIR=$(ls -d build/*/ | head -n 1)
ctest --test-dir "${BUILD_DIR}" --output-on-failure
```

When running a single test binary manually (e.g. `"${BUILD_DIR}/ctests/test_triton_mm"`), you must `export FLAGGEMS_SOURCE_DIR=$(pwd)/src/flag_gems` so that the C++ runtime can locate the Triton kernel `.py` files; `ctest` sets this automatically.

### Typical usage scenarios
With the two steps above in place, the following two usage styles automatically prefer the C++ wrapped operators; you don't need to change any call sites:

- Patch mode (`flag_gems.enable()`): monkey-patches `torch.*` entry points. When `use_c_extension` is `True`, the patched functions dispatch to `torch.ops.flag_gems.*` (§3.1); otherwise they fall back to the pure-Python implementation.
- Building models with the `nn.Module` classes FlagGems ships, e.g. `flag_gems.modules.GemsRMSNorm`. These modules already contain the "if C++ is available, call `torch.ops.flag_gems.*`; otherwise call the Python implementation" branch internally. See `gems_rms_forward` for a concrete example.
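The branch these modules carry can be sketched in plain Python. `RMSNormSketch` and `rms_norm_reference` are hypothetical names, not FlagGems classes, and plain lists stand in for tensors so the example runs anywhere. In a real module, `cpp_impl` would be something like `torch.ops.flag_gems.rms_norm` (assuming its argument order matches) and `use_c_extension` would come from `flag_gems.config`.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]
RMSNormFn = Callable[[Vector, Vector, float], Vector]


def rms_norm_reference(x: Vector, weight: Vector, eps: float) -> Vector:
    """Pure-Python fallback: y_i = x_i / sqrt(mean(x**2) + eps) * w_i."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


class RMSNormSketch:
    """Hypothetical module with a 'C++ if available, else Python' branch."""

    def __init__(self, weight: Vector, eps: float,
                 cpp_impl: Optional[RMSNormFn], use_c_extension: bool) -> None:
        self.weight = weight
        self.eps = eps
        # Decide once, at construction time, instead of on every forward call.
        self.uses_cpp = use_c_extension and cpp_impl is not None
        self._impl: RMSNormFn = cpp_impl if self.uses_cpp else rms_norm_reference

    def forward(self, x: Vector) -> Vector:
        return self._impl(x, self.weight, self.eps)
```

Resolving the branch once in `__init__` keeps the per-call path free of the check; this is a design choice of the sketch, not a claim about how `GemsRMSNorm` is written.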
## 3. Ways to invoke C++ operators
Once the C++ extensions are built, the same underlying C++ wrapper can be invoked through four different paths. Each path targets a different use case and has a different level of dispatcher overhead.
### 3.1 Via `torch.ops.flag_gems.*` (custom-op namespace)
All C++ wrappers are registered as PyTorch custom ops under the `flag_gems` namespace via `TORCH_LIBRARY(flag_gems, m)` in `src/flag_gems/csrc/cstub.cpp`. You can call them explicitly from Python, bypassing any patching logic or Python-side fallback paths:
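```python
import torch
import flag_gems  # import FlagGems before using the torch.ops.flag_gems namespace

output = torch.ops.flag_gems.fused_add_rms_norm(...)
out = torch.ops.flag_gems.mm(a, b)
```

### 3.2 Via ATen direct replacement (transparent `torch.*` patching at the dispatcher)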
For a subset of operators, FlagGems additionally registers the C++ implementations directly under the `aten::` namespace using `TORCH_LIBRARY_IMPL(aten, <dispatch_key>, m)` in `src/flag_gems/csrc/aten_patch.cpp`. The dispatch key is chosen by backend:

- `CUDA` for NVIDIA CUDA and Iluvatar CoreX (IX);
- `PrivateUse1` for Huawei Ascend (NPU) and Moore Threads (MUSA).
Because the registration goes straight into the PyTorch dispatcher, calling standard PyTorch APIs such as `torch.nonzero(x)` or `x.copy_(y)` on a supported device will transparently dispatch to the FlagGems C++ implementation, with no Python-level monkey patching required. This is the lowest-friction way to accelerate an existing model.
Because the `TORCH_LIBRARY_IMPL` registrations are compiled into the extension and run as soon as it is imported, the set of ops replaced this way is fixed at build time. Per-op opt-out is not currently supported through this path.
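Why the call sites do not change can be illustrated with a toy model of routing by (operator, dispatch key). This is not PyTorch's dispatcher, which is C++ and far more involved; `ToyDispatcher` and `DISPATCH_KEY_FOR_BACKEND` are hypothetical names, and only the backend-to-key mapping is taken from the list above.

```python
from typing import Callable, Dict, Tuple

# Backend -> dispatch key used for the ATen replacement, as listed above.
DISPATCH_KEY_FOR_BACKEND = {"CUDA": "CUDA", "IX": "CUDA", "NPU": "PrivateUse1", "MUSA": "PrivateUse1"}


class ToyDispatcher:
    """Hypothetical, much-simplified model of routing calls by (operator, dispatch key)."""

    def __init__(self) -> None:
        self._kernels: Dict[Tuple[str, str], Callable[..., object]] = {}

    def impl(self, op: str, key: str, fn: Callable[..., object]) -> None:
        # In this toy, a later registration for the same (op, key) replaces the earlier one.
        self._kernels[(op, key)] = fn

    def call(self, op: str, key: str, *args: object) -> object:
        # The caller names only the op; the registered kernel for the key does the work.
        return self._kernels[(op, key)](*args)
```

A caller that always invokes `d.call("nonzero", key, xs)` keeps working unchanged when a second implementation is registered for the same op and key; only the kernel behind it changes, which is the effect the ATen replacement relies on.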
### 3.3 Via the `c_operators` pybind module (direct, dispatcher-free)
The same C++ wrappers are also exported through a `PYBIND11_MODULE(c_operators, …)` in `src/flag_gems/csrc/cstub.cpp`:
```python
from flag_gems import c_operators

# a, b, input, residual, weight: device tensors prepared by the caller; eps: a float
out = c_operators.mm(a, b)
c_operators.fused_add_rms_norm(input, residual, weight, eps)
```

This path completely bypasses the PyTorch dispatcher, making it the lowest-overhead way to call a FlagGems C++ operator from Python. It is most useful in latency-critical microbenchmarks or tight inner loops where even the boxed dispatcher call is measurable.
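To check whether the dispatcher cost is actually measurable in your workload, a small host-side timing harness is enough. The sketch below is generic and framework-agnostic; `mean_call_latency_us` and `compare_paths` are made-up names, and the harness simply times zero-argument callables with `time.perf_counter_ns`.

```python
import time
from typing import Callable, Dict


def mean_call_latency_us(fn: Callable[[], object], iters: int = 10_000, warmup: int = 100) -> float:
    """Average host-side wall-clock time per call of `fn`, in microseconds."""
    if iters <= 0:
        raise ValueError("iters must be positive")
    for _ in range(warmup):  # warm up caches (e.g. kernel caches) before timing
        fn()
    start = time.perf_counter_ns()
    for _ in range(iters):
        fn()
    return (time.perf_counter_ns() - start) / iters / 1_000


def compare_paths(candidates: Dict[str, Callable[[], object]], iters: int = 10_000) -> Dict[str, float]:
    """Time each zero-argument callable with the same settings."""
    return {name: mean_call_latency_us(fn, iters) for name, fn in candidates.items()}
```

For example, `compare_paths({"torch.ops": lambda: torch.ops.flag_gems.mm(a, b), "c_operators": lambda: c_operators.mm(a, b)})`. GPU kernel launches are asynchronous, so this loop mainly measures host-side cost, which is where the invocation paths differ; keep `a` and `b` small so that device work does not fill the launch queue and start throttling the loop.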
### 3.4 Via the native C++ API (`flag_gems::` functions and GTest)
Every wrapper is also a regular C++ function in the `flag_gems::` namespace, declared in `include/flag_gems/operators.h` and shipped in the installed CMake target `FlagGems::operators`. Downstream C++ code can link against this target and call the operators directly:
```cpp
#include "flag_gems/operators.h"

// a, b, x and weight are at::Tensor values on a supported device; eps is the RMSNorm epsilon.
at::Tensor c = flag_gems::mm_tensor(a, b);
at::Tensor y = flag_gems::rms_norm(x, weight, eps);
```

This is exactly what the in-tree GTest suite under `ctests/` uses (e.g. `ctests/test_triton_mm.cpp`), and it is the right path when embedding FlagGems into a non-Python C++ application or when writing C++ unit tests.
### Summary
| Path | Entry point | Goes through the PyTorch dispatcher |
|---|---|---|
| `torch.ops.flag_gems.*` | `TORCH_LIBRARY(flag_gems, …)` | Yes |
| ATen replacement | `TORCH_LIBRARY_IMPL(aten, …)` | Yes |
| `flag_gems.c_operators` pybind | `PYBIND11_MODULE(c_operators, …)` | No |
| Native C++ API | `flag_gems::*` in `operators.h` | No |
## Reference: Currently supported C++-wrapped operators
The following operators currently have C++ wrappers shipped with FlagGems.
- `add` (pointwise dynamic C++)
- `div` (pointwise dynamic C++)
- `fill` (pointwise dynamic C++)
- `addmm`
- `mm`
- `bmm`
- `cat`
- `contiguous`
- `copy`
- `embedding`
- `exponential_`
- `zeros`
- `argmax`
- `max`
- `sum`
- `softmax`
- `sort`
- `topk`
- `nonzero`
- `rms_norm`
- `fused_add_rms_norm`
- `rotary_embedding`
- `flash_attn_varlen_func`
- `reshape_and_cache_flash`
- `rwkv_mm_sparsity`
- `rwkv_ka_fusion`
Operators marked as pointwise dynamic C++ are built only when the `-DFLAGGEMS_BUILD_POINTWISE_DYNAMIC_CPP=ON` CMake option is enabled. See the install guide for details.
We are actively expanding this list as part of our ongoing performance optimization work.
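To see which of these operators your build exposes under a given namespace, a quick attribute check is usually enough. This is a hypothetical diagnostic helper, not part of FlagGems. The names below are copied from the list above, but registered op names can differ from them (the native C++ API, for instance, exposes `mm` as `mm_tensor`), so treat a non-empty result as a prompt to look closer rather than as proof that an operator is missing.

```python
from typing import Iterable, List

# Names copied from the reference list above; adjust them to your build.
SUPPORTED = [
    "add", "div", "fill", "addmm", "mm", "bmm", "cat", "contiguous", "copy",
    "embedding", "exponential_", "zeros", "argmax", "max", "sum", "softmax",
    "sort", "topk", "nonzero", "rms_norm", "fused_add_rms_norm",
    "rotary_embedding", "flash_attn_varlen_func", "reshape_and_cache_flash",
    "rwkv_mm_sparsity", "rwkv_ka_fusion",
]
# Built only with -DFLAGGEMS_BUILD_POINTWISE_DYNAMIC_CPP=ON.
POINTWISE_DYNAMIC = {"add", "div", "fill"}


def missing_ops(namespace: object, names: Iterable[str], pointwise_dynamic_built: bool) -> List[str]:
    """Return the expected names that `namespace` does not expose as attributes."""
    expected = [n for n in names if pointwise_dynamic_built or n not in POINTWISE_DYNAMIC]
    return [n for n in expected if not hasattr(namespace, n)]
```

For example, `missing_ops(torch.ops.flag_gems, SUPPORTED, pointwise_dynamic_built=False)` relies on the same `hasattr(torch.ops.flag_gems, ...)` check used in the quick sanity check.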