Coverage for src/flag_gems/experimental_ops/i0.py: 0%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def i0_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise modified Bessel function of the first kind, order zero.

    Each program instance handles one BLOCK_SIZE-wide slice of a flat array.
    i0(|x|) is evaluated with the classic Abramowitz & Stegun polynomial
    approximations (9.8.1 / 9.8.2 — the same coefficients as Numerical
    Recipes' `bessi0`):

      * |x| <= 3.75: polynomial in y = (x / 3.75)^2
      * |x| >  3.75: exp(|x|) / sqrt(|x|) times a polynomial in 3.75 / |x|

    All arithmetic is done in float32; the final store casts to the dtype
    of `out_ptr`.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked-off lanes read 0, which is a harmless input for the math below.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    x_f32 = x.to(tl.float32)
    ax = tl.abs(x_f32)

    # Small region: |x| <= 3.75 — polynomial in (x/3.75)^2.
    # i0 is even, so using x (not |x|) inside t*t is equivalent.
    t = x_f32 / 3.75
    y = t * t
    p_small = 1.0 + y * (
        3.5156229
        + y
        * (
            3.0899424
            + y * (1.2067492 + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))
        )
    )

    # Large region: |x| > 3.75 — polynomial in 3.75/|x|.
    # NOTE: both branches are computed for every lane; at ax == 0 this
    # branch divides by zero, but tl.where below discards that lane's value.
    yb = 3.75 / ax
    p_big = 0.39894228 + yb * (
        0.01328592
        + yb
        * (
            0.00225319
            + yb
            * (
                -0.00157565
                + yb
                * (
                    0.00916281
                    + yb
                    * (
                        -0.02057706
                        + yb * (0.02635537 + yb * (-0.01647633 + yb * 0.00392377))
                    )
                )
            )
        )
    )
    # Avoid division by zero via masking; big branch only used when ax > 3.75
    res_big = tl.exp(ax) * p_big / tl.sqrt(ax)

    use_small = ax <= 3.75
    res = tl.where(use_small, p_small, res_big)

    # Store result; Triton will cast to the dtype of out_ptr as needed
    tl.store(out_ptr + offsets, res, mask=mask)

59 

60 

def _launch_i0(out: torch.Tensor, x: torch.Tensor):
    """Fill `out` with i0(x) by launching the Triton kernel.

    Both tensors must live on the same CUDA device and hold the same number
    of elements. The kernel computes in fp32 and stores in `out`'s dtype;
    a non-contiguous `out` is staged through a contiguous buffer and copied
    back afterwards.
    """
    assert x.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert (
        out.numel() == x.numel()
    ), "Input and output must have the same number of elements"
    assert out.device == x.device, "Input and output must be on the same device"

    # Promote integral inputs to the default floating dtype, then align the
    # input dtype with the output so the kernel's store-side cast is trivial.
    src = x if x.is_floating_point() else x.to(torch.get_default_dtype())
    if src.dtype != out.dtype:
        src = src.to(out.dtype)

    src = src.contiguous()
    needs_copy_back = not out.is_contiguous()
    dst = out.contiguous() if needs_copy_back else out

    total = dst.numel()
    BLOCK_SIZE = 1024

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    i0_kernel[grid](src, dst, total, BLOCK_SIZE=BLOCK_SIZE)

    if needs_copy_back:
        out.copy_(dst)
    return out

93 

94 

def i0(x: torch.Tensor):
    """Return a new tensor with the modified Bessel function i0 applied elementwise.

    Args:
        x: CUDA tensor of any shape. Floating inputs keep their dtype;
           integral/bool inputs produce a result in the default floating dtype.

    Returns:
        A freshly allocated tensor of the result dtype on x's device.

    Raises:
        ValueError: if `x` is not on a CUDA device.
    """
    if not x.is_cuda:
        raise ValueError("i0: input tensor must be on CUDA device")
    # Result dtype follows PyTorch's floating type behavior; use input dtype if floating, otherwise default
    out_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # empty_like with an explicit dtype allocates directly on x's device;
    # the previous x.to(dtype=out_dtype) materialized a full cast copy of x
    # just to shape the output, which was pure wasted work.
    out = torch.empty_like(x, dtype=out_dtype)
    _launch_i0(out, x)
    return out

103 

104 

def i0_out(x: torch.Tensor, out: torch.Tensor):
    """Compute i0(x) elementwise into the caller-provided tensor `out`.

    Args:
        x: CUDA input tensor.
        out: CUDA floating-point tensor with the same number of elements as x.

    Returns:
        `out`, filled with the result.

    Raises:
        ValueError: if either tensor is not on CUDA, or element counts differ.
        TypeError: if `out` is not a floating-point tensor.
    """
    if not x.is_cuda or not out.is_cuda:
        raise ValueError("i0_out: input and output tensors must be on CUDA device")
    if not out.is_floating_point():
        raise TypeError("i0_out: output tensor must be a floating point type")
    if out.numel() != x.numel():
        raise ValueError(
            "i0_out: input and output must have the same number of elements"
        )
    _launch_i0(out, x)
    return out