Coverage for src/flag_gems/experimental_ops/erfinv.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def erfinv_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise inverse error function over a flat buffer of n_elements.

    Implements M. Giles' single-precision rational/polynomial approximation
    of erfinv: two Horner-evaluated polynomials selected on
    w = -log((1 - x)(1 + x)) = -log(1 - x^2), one for the central region
    (w < 5) and one for the tails, followed by explicit fix-ups for the
    special values NaN, |x| > 1 (-> NaN) and x == +/-1 (-> +/-inf).

    Args:
        x_ptr: pointer to the input values.
        out_ptr: pointer to the output buffer (same dtype as input).
        n_elements: total number of elements to process.
        BLOCK_SIZE: compile-time block width; one program handles one block.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the final partial block.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Compute in float32 regardless of the input dtype; result is cast back
    # to x.dtype before the store.
    xf = x.to(tl.float32)

    one = 1.0
    absx = tl.abs(xf)
    # Factored form (1 - x)(1 + x) avoids cancellation near |x| = 1.
    # For |x| > 1 the argument is negative, so w is NaN here; those lanes
    # are overwritten by the mask_oob fix-up below.
    w = -tl.log((one - xf) * (one + xf))

    # Branch selector: central region uses the low-w polynomial,
    # tails (|x| close to 1) use the sqrt(w)-based polynomial.
    use_low = w < 5.0

    # Central-region polynomial in (w - 2.5), Horner form.
    # Coefficients are from Giles' erfinv approximation; do not reorder.
    wl = w - 2.5
    pl = 2.81022636e-08
    pl = 3.43273939e-07 + pl * wl
    pl = -3.5233877e-06 + pl * wl
    pl = -4.39150654e-06 + pl * wl
    pl = 2.1858087e-04 + pl * wl
    pl = -1.25372503e-03 + pl * wl
    pl = -4.17768164e-03 + pl * wl
    pl = 2.46640727e-01 + pl * wl
    pl = 1.50140941e00 + pl * wl

    # Tail polynomial in (sqrt(w) - 3), Horner form.
    wh = tl.sqrt(w) - 3.0
    ph = -2.00214257e-04
    ph = 1.00950558e-04 + ph * wh
    ph = 1.34934322e-03 + ph * wh
    ph = -3.67342844e-03 + ph * wh
    ph = 5.73950773e-03 + ph * wh
    ph = -7.62246130e-03 + ph * wh
    ph = 9.43887047e-03 + ph * wh
    ph = 1.00167406e00 + ph * wh
    ph = 2.83297682e00 + ph * wh

    # erfinv(x) ~= p(w) * x; both branches are evaluated, then selected.
    p = tl.where(use_low, pl, ph)
    res = p * xf

    nan_vec = tl.full([BLOCK_SIZE], float("nan"), dtype=tl.float32)
    inf_vec = tl.full([BLOCK_SIZE], float("inf"), dtype=tl.float32)

    # Special-value masks (x != x is the NaN self-inequality test).
    mask_nan = xf != xf
    mask_oob = absx > 1.0
    mask_pos1 = xf == 1.0
    mask_neg1 = xf == -1.0

    # Fix-ups: NaN in -> NaN, out-of-domain -> NaN, +/-1 -> +/-inf.
    res = tl.where(mask_nan, nan_vec, res)
    res = tl.where(mask_oob, nan_vec, res)
    res = tl.where(mask_pos1, inf_vec, res)
    res = tl.where(mask_neg1, -inf_vec, res)

    # Cast back to the caller's dtype and store the valid lanes.
    y = res.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)

62 

63 

def _launch_erfinv_kernel(x: torch.Tensor, out: torch.Tensor):
    """Validate the tensor pair and launch ``erfinv_kernel`` over all elements.

    Both tensors must be CUDA tensors with matching element counts and dtype;
    ``out`` is written in flat (contiguous-index) order.
    """
    assert x.is_cuda and out.is_cuda, "Inputs must be CUDA tensors"
    assert (
        x.numel() == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"

    total = x.numel()
    block = 1024

    # One program per BLOCK_SIZE-wide chunk of the flattened input.
    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    erfinv_kernel[grid](
        x,
        out,
        total,
        BLOCK_SIZE=block,
    )

79 

80 

def erfinv(x: torch.Tensor):
    """Compute the inverse error function of ``x`` elementwise.

    Args:
        x: CUDA tensor of any shape; non-contiguous inputs are handled by
           computing on a contiguous copy.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.
    """
    # The kernel indexes the buffer linearly, so it needs contiguous input.
    x_in = x if x.is_contiguous() else x.contiguous()
    out = torch.empty_like(x_in)  # contiguous, same shape/dtype as x
    _launch_erfinv_kernel(x_in, out)
    # BUGFIX: the previous version tried to restore x's layout with
    # out.as_strided(x.size(), x.stride()). as_strided reinterprets out's
    # *contiguous* storage through x's (possibly non-contiguous) strides,
    # which scrambles the computed values whenever x is non-contiguous.
    # The shape already matches x, so the contiguous result is returned
    # directly (matching the usual behavior of out-of-place torch ops).
    return out

91 

92 

def erfinv_out(x: torch.Tensor, out: torch.Tensor):
    """Write ``erfinv(x)`` into ``out`` (aten-style ``out=`` variant).

    ``out`` is resized to ``x``'s shape if needed and must share ``x``'s
    dtype. Non-contiguous ``out`` tensors are filled via a contiguous
    scratch buffer followed by ``copy_``. Returns ``out``.
    """
    # aten out semantics: reshape the destination to match the input.
    if out.shape != x.shape:
        out.resize_(x.shape)
    assert out.dtype == x.dtype, "out tensor must have the same dtype as input"

    # The kernel requires a contiguous source buffer.
    src = x.contiguous() if not x.is_contiguous() else x

    if not out.is_contiguous():
        # Compute into contiguous scratch, then let copy_ handle strides.
        scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_erfinv_kernel(src, scratch)
        out.copy_(scratch)
        return out

    _launch_erfinv_kernel(src, out)
    return out

107 return out