Coverage for src/flag_gems/experimental_ops/special_i1.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def special_i1_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise modified Bessel function of the first kind, order 1.

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    input.  Math is done in float32 and the result is cast back to the
    input dtype on store.  Polynomial approximations follow Abramowitz &
    Stegun 9.8.3 (small argument) and 9.8.4 (large argument).
    """
    prog = tl.program_id(axis=0)
    offs = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    raw = tl.load(x_ptr + offs, mask=in_bounds, other=0.0)
    xf = raw.to(tl.float32)
    abs_x = tl.abs(xf)

    # ---- small branch, |x| <= 3.75: x * poly((x/3.75)^2), Horner form ----
    r = xf / 3.75
    r2 = r * r
    poly = 0.00032411
    poly = 0.00301532 + r2 * poly
    poly = 0.02658733 + r2 * poly
    poly = 0.15084934 + r2 * poly
    poly = 0.51498869 + r2 * poly
    poly = 0.87890594 + r2 * poly
    poly = 0.5 + r2 * poly
    small_val = xf * poly

    # ---- large branch, |x| > 3.75: exp(|x|)/sqrt(|x|) * poly(3.75/|x|) ----
    # The 1e-20 floor only guards against 0/0 on lanes that take the small
    # branch anyway; it never affects selected results.
    inv = 3.75 / tl.maximum(abs_x, 1e-20)
    acc = -0.00420059
    acc = 0.01787654 + inv * acc
    acc = -0.02895312 + inv * acc
    acc = 0.02282967 + inv * acc
    acc = -0.01031555 + inv * acc
    acc = 0.00163801 + inv * acc
    acc = -0.00362018 + inv * acc
    acc = -0.03988024 + inv * acc
    acc = 0.39894228 + inv * acc
    scale = tl.exp(abs_x) / tl.sqrt(tl.maximum(abs_x, 1e-20))
    large_val = scale * acc
    # Restore the sign: I1 is an odd function.
    large_val = tl.where(xf < 0, -large_val, large_val)

    result = tl.where(abs_x <= 3.75, small_val, large_val)

    tl.store(out_ptr + offs, result.to(raw.dtype), mask=in_bounds)

53 

54 

def _launch_special_i1(x: torch.Tensor, out: torch.Tensor):
    """Validate arguments and launch ``special_i1_kernel`` over flat data.

    Both tensors must be CUDA tensors of the same dtype with equal element
    counts; an empty input is a no-op.
    """
    assert x.is_cuda and out.is_cuda, "Tensors must be CUDA tensors"
    assert (
        x.numel() == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"

    total = x.numel()
    if total == 0:
        # Nothing to do; launching a zero-sized grid is pointless.
        return

    block = 1024

    def grid(meta):
        # One program per BLOCK_SIZE-wide chunk of the flattened tensor.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    special_i1_kernel[grid](x, out, total, BLOCK_SIZE=block)

69 

70 

def special_i1(self: torch.Tensor):
    """Return a new tensor with I1 applied elementwise to *self*.

    The kernel requires contiguous memory, so the input is made contiguous
    first; the result is reshaped to match the original tensor's shape.
    """
    src = self.contiguous()
    result = torch.empty_like(src)
    _launch_special_i1(src, result)
    # A contiguous strided input can take the kernel output directly;
    # otherwise hand back a view with the caller's original shape.
    if self.layout == torch.strided and self.is_contiguous():
        return result
    return result.view_as(self)

81 

82 

def special_i1_out(self: torch.Tensor, out: torch.Tensor):
    """Compute I1 of *self* elementwise into *out* and return *out*.

    Raises:
        TypeError: if *out* differs from *self* in dtype or device.
    """
    if out.dtype != self.dtype:
        raise TypeError("out dtype must match input dtype")
    if out.device != self.device:
        raise TypeError("out device must match input device")

    src = self.contiguous()
    dst = out.contiguous()
    _launch_special_i1(src, dst)
    # If contiguous() had to copy, the kernel wrote into a temporary;
    # propagate those results back into the caller's tensor.
    if dst.data_ptr() != out.data_ptr():
        out.copy_(dst)
    return out
96 return out