Coverage for src/flag_gems/experimental_ops/i0_.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def i0_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place modified Bessel function of the first kind, order zero.

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    buffer behind ``x_ptr`` and overwrites it with ``i0(x)``.

    Uses the classic Abramowitz & Stegun polynomial approximations
    (9.8.1 for |x| <= 3.75, 9.8.2 otherwise). All math is performed in
    float32 regardless of the stored dtype, then cast back on store.

    Args:
        x_ptr: pointer to the tensor data (read and written in place).
        n_elements: total number of elements in the tensor.
        BLOCK_SIZE: compile-time block width per program instance.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # other=0.0 puts masked-off lanes in the small-|x| branch (harmless;
    # they are never stored back because of the mask on tl.store).
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    xf = tl.cast(x, tl.float32)
    # i0 is an even function, so only |x| is needed.
    ax = tl.abs(xf)

    # Small-argument branch (A&S 9.8.1): polynomial in (x/3.75)^2.
    t_small = ax / 3.75
    y_small = t_small * t_small
    poly_small = 1.0 + y_small * (
        3.5156229
        + y_small
        * (
            3.0899424
            + y_small
            * (
                1.2067492
                + y_small * (0.2659732 + y_small * (0.0360768 + y_small * 0.0045813))
            )
        )
    )

    # Large-argument branch (A&S 9.8.2): i0(x) ~ exp(x)/sqrt(x) * poly(3.75/x).
    y_large = 3.75 / ax
    poly_large = 0.39894228 + y_large * (
        0.01328592
        + y_large
        * (
            0.00225319
            + y_large
            * (
                -0.00157565
                + y_large
                * (
                    0.00916281
                    + y_large
                    * (
                        -0.02057706
                        + y_large
                        * (0.02635537 + y_large * (-0.01647633 + y_large * 0.00392377))
                    )
                )
            )
        )
    )
    # NOTE(review): lanes with ax == 0 divide by zero here, but those lanes
    # always take the small-argument branch in the tl.where below, so the
    # inf/nan never reaches the stored result.
    val_large = tl.exp(ax) * poly_large / tl.sqrt(ax)

    # 3.75 is the crossover point of the two A&S approximations.
    result = tl.where(ax <= 3.75, poly_small, val_large)

    # Cast back to the input dtype and overwrite in place.
    result_cast = tl.cast(result, x.dtype)
    tl.store(x_ptr + offsets, result_cast, mask=mask)

60 

61 

# Keep a reference to the Triton kernel before defining the Python wrapper
# with the same name; the wrapper below launches it via this alias.
i0__kernel = i0_

64 

65 

def i0_(*args, **kwargs):
    """In-place i0: overwrite the input tensor with i0(x) and return it.

    The tensor may be passed as the first positional argument or under one
    of the keyword names ``input``, ``self``, or ``x``. The tensor must be
    a contiguous CUDA tensor of a supported floating-point dtype.

    Returns:
        The same tensor object, modified in place.

    Raises:
        ValueError: if no tensor argument could be located.
        AssertionError: if the tensor is not on CUDA, not contiguous,
            or has an unsupported dtype.
    """
    if args:
        x = args[0]
    else:
        # Fall back to the common keyword spellings, first match wins.
        x = next((kwargs[name] for name in ("input", "self", "x") if name in kwargs), None)
    if x is None:
        raise ValueError(
            "i0_ expects a tensor as the first positional argument or in keyword 'input'/'self'/'x'."
        )

    if not x.is_cuda:
        raise AssertionError("Input tensor must be on a CUDA device.")
    if not x.is_contiguous():
        raise AssertionError("Input tensor must be contiguous.")
    supported = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if x.dtype not in supported:
        raise AssertionError(
            "Unsupported dtype for i0_. Supported: float16, bfloat16, float32, float64."
        )

    total = x.numel()
    # Nothing to launch for an empty tensor; keep the in-place contract.
    if total == 0:
        return x

    # One program per BLOCK_SIZE-wide slice of the flattened tensor.
    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    i0__kernel[grid](x, total, BLOCK_SIZE=1024)
    return x