Coverage for src/flag_gems/ops/nan_to_num.py: 73%

26 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

8 

# Logger for this module; the host wrapper emits a debug line on every call.
logger = logging.getLogger(__name__)

# Short alias for the isnan implementation exposed by tl_extra_shim; the
# jitted kernel below refers to it as a global.
_isnan = tl_extra_shim.isnan

12 

13 

@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def nan_to_num_func(x, nan, posinf, neginf):
    # Elementwise body: replaces NaN, +inf and -inf in ``x`` with the scalar
    # arguments ``nan``, ``posinf`` and ``neginf`` respectively. Only ``x`` is
    # a tensor (see ``is_tensor``); the output dtype follows argument 0.
    #
    # All three masks are computed from the untouched input, so any element
    # matches at most one of them and the order of the selects below does not
    # affect the result values.
    #
    # The NaN test runs on a float32 upcast of ``x``. Presumably this is
    # because the shim's isnan lacks half/bfloat16 overloads (confirm).
    nan_mask = _isnan(x.to(tl.float32))
    pos_inf_mask = x == float("inf")
    neg_inf_mask = x == -float("inf")

    result = tl.where(nan_mask, nan, x)
    result = tl.where(pos_inf_mask, posinf, result)
    result = tl.where(neg_inf_mask, neginf, result)
    return result

26 

27 

# nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
def nan_to_num(A, nan=None, posinf=None, neginf=None):
    """Return a copy of ``A`` with NaN, +inf and -inf replaced by finite values.

    Args:
        A: input tensor.
        nan: replacement for NaN; defaults to ``0.0``.
        posinf: replacement for +inf; defaults to ``torch.finfo(A.dtype).max``.
        neginf: replacement for -inf; defaults to ``torch.finfo(A.dtype).min``.

    Returns:
        A new tensor. Integer and bool inputs cannot hold NaN/inf, so they are
        returned as an unchanged clone, matching ``torch.nan_to_num``.
        Floating-point (and complex) inputs go through the Triton kernel.

    NOTE(review): if ``pointwise_dynamic`` forwards Python floats to the kernel
    as fp32 scalars, the float64 defaults (finfo(float64).max/min) would
    overflow to +/-inf inside the kernel. Confirm for float64 inputs.
    """
    logger.debug("GEMS NAN_TO_NUM TENSOR")

    # torch.finfo raises TypeError for integer/bool dtypes. Such tensors have
    # nothing to replace, so mirror torch.nan_to_num and return a plain copy.
    if not A.is_floating_point() and not A.is_complex():
        return A.clone()

    # Look up the dtype limits once, and only when a default is actually needed.
    if posinf is None or neginf is None:
        finfo = torch.finfo(A.dtype)
        if posinf is None:
            posinf = finfo.max
        if neginf is None:
            neginf = finfo.min
    if nan is None:
        nan = 0.0
    return nan_to_num_func(A, nan, posinf, neginf)