Coverage for src/flag_gems/ops/special_i0e.py: 53%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6 

@triton.jit
def _special_i0e_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise exponentially scaled modified Bessel function:
    i0e(x) = I0(x) * exp(-|x|).

    Uses the classic two-region polynomial approximations for I0
    (Abramowitz & Stegun 9.8.1 / 9.8.2, as tabulated in Numerical Recipes):

      * |x| <= 3.75: polynomial in (x/3.75)^2 approximates I0, then the
        result is scaled by exp(-|x|).
      * |x|  > 3.75: asymptotic polynomial in 3.75/|x| divided by sqrt(|x|)
        yields i0e directly, so exp(|x|) is never computed and the large
        branch cannot overflow.

    Args:
        x_ptr: pointer to the input buffer (contiguous).
        out_ptr: pointer to the output buffer (contiguous, same dtype).
        n_elements: total number of elements to process.
        BLOCK_SIZE: elements handled per program instance (compile-time).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # `other=0.0` gives masked lanes a defined value instead of undefined
    # memory; their results are discarded by the masked store, but the math
    # below should never run on garbage.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # Compute in fp32 for accuracy/stability.
    # NOTE(review): float64 inputs are also routed through fp32 here, so they
    # only get single precision — confirm that is acceptable for callers.
    xf = x.to(tl.float32)
    ax = tl.abs(xf)

    # Small region: |x| <= 3.75.
    t_small = ax / 3.75
    t2 = t_small * t_small
    # Polynomial approximation for I0 in the small region (Numerical Recipes).
    p = 1.0 + t2 * (
        3.5156229
        + t2
        * (
            3.0899424
            + t2 * (1.2067492 + t2 * (0.2659732 + t2 * (0.0360768 + t2 * 0.0045813)))
        )
    )
    small = p * tl.exp(-ax)

    # Large region: |x| > 3.75, asymptotic expansion avoids exp overflow:
    # i0e(x) = I0(x)*exp(-|x|) ≈ (1/sqrt(|x|)) * poly(3.75/|x|)
    # (t may be inf when ax == 0, producing NaN in `large`; tl.where below
    # selects the small branch for those lanes, so the NaN never escapes.)
    t = 3.75 / ax
    q = 0.39894228 + t * (
        0.01328592
        + t
        * (
            0.00225319
            + t
            * (
                -0.00157565
                + t
                * (
                    0.00916281
                    + t
                    * (
                        -0.02057706
                        + t * (0.02635537 + t * (-0.01647633 + t * 0.00392377))
                    )
                )
            )
        )
    )
    large = q / tl.sqrt(ax)

    is_large = ax > 3.75
    y = tl.where(is_large, large, small)

    # Cast back to the input dtype for storage.
    y = y.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)

65 

66 

def _run_special_i0e_kernel(x: torch.Tensor, out: torch.Tensor):
    """Launch the i0e kernel, computing i0e(x) elementwise into ``out``.

    Args:
        x: CUDA input tensor (float16 / bfloat16 / float32 / float64).
        out: CUDA output tensor with the same dtype and shape as ``x``.

    Returns:
        ``out`` (the caller-provided tensor, filled with results).

    Raises:
        AssertionError: if device, dtype, or shape requirements are violated.
    """
    assert x.is_cuda and out.is_cuda, "Tensors must be CUDA tensors"
    assert x.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ), "Unsupported dtype"
    assert out.dtype == x.dtype, "Output dtype must match input dtype"
    # A shape mismatch would make the flat 1-D indexing in the kernel read or
    # write out of bounds; callers that need broadcasting must expand first.
    assert x.shape == out.shape, "Input and output shapes must match"

    # The kernel assumes dense 1-D layouts; these may be copies for
    # non-contiguous inputs.
    x_c = x.contiguous()
    out_c = out.contiguous()

    n_elements = out_c.numel()
    if n_elements == 0:
        # Nothing to do; launching a zero-sized grid is pointless.
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _special_i0e_kernel[grid](x_c, out_c, n_elements, BLOCK_SIZE=1024)

    # If contiguous() had to copy, write the results back into the
    # caller-visible tensor.
    if out_c.data_ptr() != out.data_ptr():
        out.copy_(out_c)
    return out

90 

91 

def special_i0e(x: torch.Tensor):
    """
    ATen wrapper: special_i0e(Tensor self) -> Tensor

    Allocates a fresh tensor matching ``x`` and fills it with i0e(x).
    """
    result = torch.empty_like(x)
    _run_special_i0e_kernel(x, result)
    return result

98 

99 

def special_i0e_out(x: torch.Tensor, out: torch.Tensor):
    """
    ATen wrapper: special_i0e.out(Tensor self, Tensor out) -> Tensor

    Computes i0e(x) into the caller-provided ``out`` tensor and returns it.
    """
    # Broadcast the input up to out's shape when the shapes differ.
    src = x if x.shape == out.shape else x.expand(out.shape)
    _run_special_i0e_kernel(src, out)
    return out