Coverage for src/flag_gems/experimental_ops/special_i0e.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _special_i0e_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Elementwise kernel for the exponentially scaled modified Bessel
    # function of the first kind, order zero: i0e(x) = exp(-|x|) * I0(x).
    #
    # Args:
    #   x_ptr: pointer to the (contiguous) input tensor.
    #   out_ptr: pointer to the (contiguous) output tensor.
    #   n_elements: total number of elements to process.
    #   BLOCK_SIZE: compile-time count of elements handled per program.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block against out-of-bounds loads/stores.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # Compute in fp32 for accuracy/stability.
    # NOTE(review): the host wrapper also admits float64 inputs; those are
    # down-converted to fp32 here, so float64 results carry only
    # single-precision accuracy — confirm this is intended.
    xf = x.to(tl.float32)
    # i0e is an even function, so only |x| matters.
    ax = tl.abs(xf)

    # Small region: x <= 3.75
    t_small = ax / 3.75
    t2 = t_small * t_small
    # Polynomial approximation for I0 in small region (Numerical Recipes);
    # multiplied by exp(-|x|) below to obtain the scaled value.
    p = 1.0 + t2 * (
        3.5156229
        + t2
        * (
            3.0899424
            + t2 * (1.2067492 + t2 * (0.2659732 + t2 * (0.0360768 + t2 * 0.0045813)))
        )
    )
    small = p * tl.exp(-ax)

    # Large region: x > 3.75, use asymptotic expansion to avoid exp overflow:
    # the exp(|x|) growth of I0 cancels analytically, so nothing overflows.
    # i0e(x) = I0(x)*exp(-|x|) ≈ (1/sqrt(|x|)) * poly(3.75/|x|)
    t = 3.75 / ax
    q = 0.39894228 + t * (
        0.01328592
        + t
        * (
            0.00225319
            + t
            * (
                -0.00157565
                + t
                * (
                    0.00916281
                    + t
                    * (
                        -0.02057706
                        + t * (0.02635537 + t * (-0.01647633 + t * 0.00392377))
                    )
                )
            )
        )
    )
    large = q / tl.sqrt(ax)

    # Both branches are evaluated for every lane; tl.where selects per
    # element. When ax == 0 the large branch divides by zero, but that
    # lane's result is discarded in favor of the small branch.
    is_large = ax > 3.75
    y = tl.where(is_large, large, small)

    # Cast back to input dtype for storage
    y = y.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)

64 

65 

def _run_special_i0e_kernel(x: torch.Tensor, out: torch.Tensor):
    """Launch the i0e Triton kernel, writing the result into ``out``.

    Both tensors must be CUDA tensors sharing one of the supported
    floating-point dtypes. Returns ``out``.
    """
    assert x.is_cuda and out.is_cuda, "Tensors must be CUDA tensors"
    assert x.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ), "Unsupported dtype"
    assert out.dtype == x.dtype, "Output dtype must match input dtype"

    # The kernel assumes flat, densely packed memory on both sides.
    src = x.contiguous()
    dst = out.contiguous()

    total = dst.numel()
    if total == 0:
        # Nothing to launch for an empty tensor.
        return out

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the flattened tensor.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _special_i0e_kernel[grid](src, dst, total, BLOCK_SIZE=1024)

    # If ``out`` was non-contiguous, the kernel wrote into a temporary
    # contiguous copy; move the results back into the caller's tensor.
    if dst.data_ptr() != out.data_ptr():
        out.copy_(dst)
    return out

89 

90 

def special_i0e(x: torch.Tensor):
    """
    ATen wrapper: special_i0e(Tensor self) -> Tensor

    Allocates a fresh tensor matching ``x`` and fills it with i0e(x)
    via the shared kernel launcher.
    """
    result = torch.empty_like(x)
    return _run_special_i0e_kernel(x, result)

97 

98 

def special_i0e_out(x: torch.Tensor, out: torch.Tensor):
    """
    ATen wrapper: special_i0e.out(Tensor self, Tensor out) -> Tensor
    """
    # Broadcast the input up to the destination's shape when they differ;
    # the launcher's .contiguous() call materializes the expanded view.
    src = x if x.shape == out.shape else x.expand(out.shape)
    _run_special_i0e_kernel(src, out)
    return out