Coverage for src/flag_gems/ops/i0_.py: 46%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

# Module-level logger for this op; i0_ sends a debug record through it on every call.
logger = logging.getLogger(name=__name__)

9 

10 

@triton.jit
def i0_kernel_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite the first ``n_elements`` values at ``x_ptr`` with I0(x) in place.

    I0 is the modified Bessel function of the first kind, order zero. It is
    evaluated with the Abramowitz & Stegun polynomial approximations 9.8.1
    (|x| <= 3.75) and 9.8.2 (|x| > 3.75). Each program instance handles one
    chunk of ``BLOCK_SIZE`` consecutive elements; out-of-range lanes are masked.

    Arithmetic runs in float32 for 16/32-bit inputs and in float64 for float64
    inputs. The result is cast back to the input dtype before storing.
    """
    pid = tl.program_id(axis=0)
    # Widen before multiplying: an int32 product wraps once the tensor has
    # more than 2**31 elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Adding a float32 zero uses Triton's binary-op promotion: fp16/bf16/fp32
    # become fp32, while fp64 stays fp64 instead of being truncated to fp32.
    xf = x + tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    ax = tl.abs(xf)

    # |x| <= 3.75 (A&S 9.8.1): polynomial in t^2 with t = |x| / 3.75,
    # evaluated in Horner form from the highest coefficient down.
    t = ax / 3.75
    y_small = t * t
    p = 0.0360768 + y_small * 0.0045813
    p = 0.2659732 + y_small * p
    p = 1.2067492 + y_small * p
    p = 3.0899424 + y_small * p
    p = 3.5156229 + y_small * p
    poly_small = 1.0 + y_small * p

    # |x| > 3.75 (A&S 9.8.2): I0(x) ~ exp(|x|) / sqrt(|x|) * q(3.75 / |x|).
    y_large = 3.75 / ax
    q = -0.01647633 + y_large * 0.00392377
    q = 0.02635537 + y_large * q
    q = -0.02057706 + y_large * q
    q = 0.00916281 + y_large * q
    q = -0.00157565 + y_large * q
    q = 0.00225319 + y_large * q
    q = 0.01328592 + y_large * q
    poly_large = 0.39894228 + y_large * q
    # Apply exp(|x|) as two half-exponent factors. A single exp(|x|)
    # overflows before the final product does.
    e_half = tl.exp(0.5 * ax)
    val_large = e_half * (poly_large / tl.sqrt(ax)) * e_half
    # At |x| = inf the product above is inf * 0 = NaN, but I0(+-inf) = +inf.
    val_large = tl.where(ax == float("inf"), ax, val_large)

    # Both branches are computed for every lane. tl.where discards the
    # unused one, including the inf values the large branch yields at x = 0.
    result = tl.where(ax <= 3.75, poly_small, val_large)
    tl.store(x_ptr + offsets, result.to(x.dtype), mask=mask)

65 

66 

def i0_(*args, **kwargs):
    """In-place I0: replace every element of a tensor with I0(x) and return it.

    The tensor is taken from the first positional argument. If there are no
    positional arguments, it is taken from the first of the keywords
    ``input``, ``self`` or ``x`` that is present.

    Non-contiguous tensors are handled by running the kernel on a contiguous
    copy and copying the result back into the original tensor.

    Returns:
        The same tensor object, overwritten with I0 of its previous values.

    Raises:
        ValueError: if no tensor argument was supplied (or it is None).
        AssertionError: if the tensor is not on a CUDA device, or its dtype
            is not float16, bfloat16, float32 or float64.
    """
    logger.debug("GEMS I0_")
    if args:
        x = args[0]
    else:
        # First matching keyword wins, even if its value is None.
        x = next((kwargs[k] for k in ("input", "self", "x") if k in kwargs), None)
    if x is None:
        raise ValueError(
            "i0_ expects a tensor as the first positional argument or in keyword 'input'/'self'/'x'."
        )

    if not x.is_cuda:
        raise AssertionError("Input tensor must be on a CUDA device.")
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise AssertionError(
            "Unsupported dtype for i0_. Supported: float16, bfloat16, float32, float64."
        )

    n_elements = x.numel()
    if n_elements == 0:
        return x

    # The kernel addresses memory linearly, so it needs a dense buffer.
    # For a strided tensor, compute into a contiguous copy and write it back.
    buf = x if x.is_contiguous() else x.contiguous()
    block_size = 1024
    grid = (triton.cdiv(n_elements, block_size),)
    i0_kernel_[grid](buf, n_elements, BLOCK_SIZE=block_size)
    if buf is not x:
        x.copy_(buf)
    return x