Coverage for src/flag_gems/ops/nan_to_num.py: 73%
26 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
# Module logger; nan_to_num emits a debug record on every call.
logger = logging.getLogger(__name__)
# Local alias for the shim's isnan so the jitted kernel can call it by a
# short module-level name.
_isnan = tl_extra_shim.isnan
@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def nan_to_num_func(x, nan, posinf, neginf):
    """Per-element body: NaN -> ``nan``, +inf -> ``posinf``, -inf -> ``neginf``.

    Only ``x`` is a tensor operand; ``nan``, ``posinf`` and ``neginf`` are
    scalars (see ``is_tensor`` in the decorator). Output dtype promotion is
    driven by argument 0 with the DEFAULT method.

    All three masks are taken from the *original* ``x`` before any
    substitution, so a replacement value (e.g. ``nan=inf``) is never
    re-matched by a later mask.
    """
    # The NaN test is performed on a float32 view of x.
    as_f32 = x.to(tl.float32)
    nan_mask = _isnan(as_f32)
    pos_mask = x == float("inf")
    neg_mask = x == -float("inf")

    out = tl.where(nan_mask, nan, x)
    out = tl.where(pos_mask, posinf, out)
    return tl.where(neg_mask, neginf, out)
# nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
def nan_to_num(A, nan=None, posinf=None, neginf=None):
    """Replace NaN, +inf and -inf entries of ``A`` with finite values.

    Follows the ``torch.nan_to_num`` schema noted above.

    Args:
        A: Input tensor.
        nan: Value written where ``A`` is NaN. Defaults to ``0.0``.
        posinf: Value written where ``A`` is +inf. Defaults to
            ``torch.finfo(A.dtype).max``.
        neginf: Value written where ``A`` is -inf. Defaults to
            ``torch.finfo(A.dtype).min``.

    Returns:
        For integer/bool ``A``: a copy of ``A`` (such tensors cannot hold
        NaN or inf, so there is nothing to replace). Otherwise: the result
        of the ``nan_to_num_func`` pointwise kernel, whose output dtype is
        promoted from ``A``.
    """
    logger.debug("GEMS NAN_TO_NUM TENSOR")
    if not (A.is_floating_point() or A.is_complex()):
        # torch.finfo() raises TypeError for integral/bool dtypes, and
        # torch.nan_to_num returns an unchanged copy for them; do the same
        # instead of crashing on the default-bound lookups below.
        return A.clone()
    # finfo is only consulted when a bound is not supplied by the caller.
    if posinf is None:
        posinf = torch.finfo(A.dtype).max
    if neginf is None:
        neginf = torch.finfo(A.dtype).min
    if nan is None:
        nan = 0.0
    return nan_to_num_func(A, nan, posinf, neginf)