Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/nan_to_num.py: 0%
27 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, created as a child of the package-level "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Backend-specific isnan helper, bound once at import time from tl_extra_shim
# so the Triton kernel below can call it by a short module-level name.
_isnan = tl_extra_shim.isnan
# Element-wise kernel: replace NaN / +inf / -inf in `x` with the scalar
# substitutes `nan`, `posinf` and `neginf`.
# Only argument 0 is a tensor; the other three are passed as scalars, and the
# output dtype is presumably derived from argument 0 via DEFAULT promotion
# (see pointwise_dynamic) — confirm against that helper.
@pointwise_dynamic(
    is_tensor=[True, False, False, False], promotion_methods=[(0, "DEFAULT")]
)
@triton.jit
def nan_to_num_func(x, nan, posinf, neginf):
    # All three masks are taken from the untouched input, so the order of the
    # replacements below cannot affect which elements get substituted.
    is_pos_inf = x == float("inf")
    is_neg_inf = x == -float("inf")
    # NaN test runs on an fp32 copy; presumably the backend isnan helper only
    # accepts fp32 — confirm. The cast preserves NaN-ness either way.
    is_nan = _isnan(x.to(tl.float32))
    out = tl.where(is_nan, nan, x)
    out = tl.where(is_pos_inf, posinf, out)
    out = tl.where(is_neg_inf, neginf, out)
    return out
# nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
def nan_to_num(A, nan=None, posinf=None, neginf=None):
    """Replace NaN, +inf and -inf entries of ``A`` with finite values.

    Args:
        A: input tensor.
        nan: value substituted for NaN entries; ``None`` means ``0.0``.
        posinf: value substituted for +inf entries; ``None`` means
            ``torch.finfo(A.dtype).max``.
        neginf: value substituted for -inf entries; ``None`` means
            ``torch.finfo(A.dtype).min``.

    Returns:
        For integer and bool inputs, a copy of ``A`` (such dtypes cannot hold
        NaN or inf). Otherwise, the result of ``nan_to_num_func``.
    """
    logger.debug("GEMS NAN_TO_NUM TENSOR")
    # Integer/bool tensors cannot contain NaN or inf. torch.finfo() raises
    # TypeError for them, whereas torch.nan_to_num returns an unchanged copy,
    # so short-circuit here to match that behavior.
    if not (A.is_floating_point() or A.is_complex()):
        return A.clone()
    # Defaults are looked up lazily, only when the caller did not supply them.
    if posinf is None:
        posinf = torch.finfo(A.dtype).max
    if neginf is None:
        neginf = torch.finfo(A.dtype).min
    if nan is None:
        nan = 0.0
    return nan_to_num_func(A, nan, posinf, neginf)