Precision Checking (Experimental)#
FlagGems provides an experimental precision-checking mechanism that automatically compares the output of FlagGems operators against native PyTorch (CPU) results, and logs any discrepancies to a file. This is useful for verifying numerical correctness during development.
How to Enable#
Enabling precision checking requires two steps:
- Call
enable_precision_check()fromflag_gems.logging_utilsto configure the precision logger. - Pass
PrecisionCheckRegisteras theregistrarparameter toenable()oronly_enable(), so that operators are wrapped with precision-checking logic.
import flag_gems
from flag_gems.logging_utils import enable_precision_check
from flag_gems.runtime.precision_register import PrecisionCheckRegister
# Step 1: Configure precision checking (initialize the precision logger)
enable_precision_check()
# Step 2: Register all operators with PrecisionCheckRegister
flag_gems.enable(registrar=PrecisionCheckRegister)
# Run your model or operators as usual
output = model(input)You can also use it with only_enable() to check specific operators:
from flag_gems.logging_utils import enable_precision_check
from flag_gems.runtime.precision_register import PrecisionCheckRegister
enable_precision_check(rtol=1e-3, atol=1e-4)
flag_gems.only_enable(
include=["mm", "add", "softmax"],
registrar=PrecisionCheckRegister,
)Configuration#
You can customize the precision checking behavior by passing parameters
to enable_precision_check().
| Parameter | Type | Default | Description |
|---|---|---|---|
rtol | float | 1e-4 | Relative tolerance |
atol | float | 1e-5 | Absolute tolerance |
max_checks | int | 10 | Max checks per operator before skipping |
log_once | bool | True | Only log the first failure per operator |
path | str | ~/.flaggems/precision.log | Log file path |
from flag_gems.logging_utils import enable_precision_check
enable_precision_check(
rtol=1e-3,
atol=1e-5,
max_checks=20,
path="./my_precision.log",
)Disabling#
To disable precision checking at runtime:
from flag_gems.logging_utils import disable_precision_check
disable_precision_check()Log Output#
Precision check results are written to ~/.flaggems/precision.log by default.
Only operators that fail the tolerance check will be logged.
Sample log content:
$ cat ~/.flaggems/precision.log
2025-05-19 10:00:01 [WARNING] Op: add.Tensor | FAIL | in: [(2, 3):torch.float16] | out: (2, 3):torch.float16 | max_abs: 1.200000e-03 | max_rel: 2.500000e-02 | rtol=0.01, atol=0.01Behavior Details#
The precision checker has several built-in safeguards to minimize performance impact:
- Only the first N calls per operator are checked (controlled by
max_checks) - Tensors larger than 1M elements are skipped to avoid copy overhead
- Once an operator logs a failure, it will not be checked again
- Pure layout/memory ops (
clone,view,copy_, etc.) are automatically skipped - Random sampling ops (
uniform_,normal_, etc.) are automatically skipped .outvariant operators are skipped- For
float16/bfloat16inputs, tolerance is automatically relaxed to at least1e-2
How It Works#
When PrecisionCheckRegister is used as the registrar, each operator
is wrapped with a precision-checking decorator. The wrapper:
- Executes the FlagGems (GPU) implementation normally.
- Copies inputs to CPU and runs the native
atenoperator as reference. - Compares the two results using the configured tolerance.
- Logs a warning if the results diverge beyond tolerance.
Precision checking copies GPU tensors to CPU and runs native PyTorch computation as a reference. This incurs significant performance overhead. This feature is intended for development and debugging only — do not enable it in production.