Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/nan_to_num.py: 0%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import tl_extra_shim 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

# Alias the NaN predicate exposed by tl_extra_shim so the kernel below can
# call it through a short module-level name.
_isnan = tl_extra_shim.isnan

# Child of the "flag_gems" logger, named after this module; lstrip(".")
# drops any leading dots from __name__ before it is used as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def nan_to_num_func(x, nan, posinf, neginf):
    # Pointwise body: swap NaN, +inf and -inf entries of `x` for the scalar
    # replacements `nan`, `posinf` and `neginf`; other entries pass through.
    #
    # All three masks are taken from the incoming `x` before any replacement
    # is applied, so a replacement value can never be re-matched by a later
    # mask.
    # The NaN test runs on a float32 cast of the input.
    nan_mask = _isnan(x.to(tl.float32))
    posinf_mask = x == float("inf")
    neginf_mask = x == -float("inf")

    out = tl.where(nan_mask, nan, x)
    out = tl.where(posinf_mask, posinf, out)
    out = tl.where(neginf_mask, neginf, out)
    return out

27 

28 

# nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
def nan_to_num(A, nan=None, posinf=None, neginf=None):
    """Return a tensor with NaN and infinite entries of ``A`` replaced.

    Args:
        A: Input tensor.
        nan: Replacement for NaN entries; ``0.0`` when ``None``.
        posinf: Replacement for ``+inf`` entries; the largest finite value of
            ``A.dtype`` when ``None``.
        neginf: Replacement for ``-inf`` entries; the lowest finite value of
            ``A.dtype`` when ``None``.

    Returns:
        The result of ``nan_to_num_func`` for floating-point or complex
        input, or a copy of ``A`` for any other dtype.
    """
    logger.debug("GEMS NAN_TO_NUM TENSOR")

    # Integer and bool tensors cannot hold NaN or infinities, and
    # torch.finfo() raises TypeError for those dtypes. Hand back an unchanged
    # copy instead of failing while resolving the default replacements.
    if not (A.is_floating_point() or A.is_complex()):
        return A.clone()

    if posinf is None:
        posinf = torch.finfo(A.dtype).max
    if neginf is None:
        neginf = torch.finfo(A.dtype).min
    if nan is None:
        nan = 0.0
    return nan_to_num_func(A, nan, posinf, neginf)