Coverage for src/flag_gems/runtime/backend/_metax/ops/prod.py: 0%

90 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger("flag_gems." + __name__) 

14 

15 

@triton.jit
def reduce_mul(a, b):
    """Combine function for tl.reduce: multiply two partial products."""
    product = a * b
    return product

19 

20 

@libentry()
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of a two-stage full-tensor product reduction.

    Each program multiplies one BLOCK_SIZE slice of the flattened input
    and writes its partial product to mid[pid].
    """
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-range lanes load 1.0, the multiplicative identity.
    vals = tl.load(inp + offsets, mask=in_bounds, other=1.0).to(tl.float32)
    partial = tl.reduce(vals, axis=0, combine_fn=reduce_mul)
    tl.store(mid + block_id, partial.to(vals.dtype))

37 

38 

@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage: fold all partial products in mid into the scalar out."""
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size
    # Lanes past mid_size contribute the identity 1.0.
    partials = tl.load(mid + idx, mask=valid, other=1.0).to(tl.float32)
    total = tl.reduce(partials, axis=0, combine_fn=reduce_mul)
    tl.store(out, total)

48 

49 

def prod(inp, *, dtype=None):
    """Product of all elements of `inp`, computed as a two-stage reduction.

    Args:
        inp: input tensor; all elements are reduced.
        dtype: optional output dtype; defaults to ``inp.dtype``.

    Returns:
        A 0-dim tensor holding the product.
    """
    logger.debug("METAX GEMS PROD")
    if dtype is None:
        dtype = inp.dtype

    M = inp.numel()
    # Empty input: the product over no elements is the multiplicative
    # identity.  Guard explicitly -- next_power_of_2(0) yields 0, which
    # would make triton.cdiv divide by zero below.
    if M == 0:
        return torch.ones([], dtype=dtype, device=inp.device)

    # The kernels index flat memory, so a contiguous layout is required
    # (no-op when the input is already contiguous).
    inp = inp.contiguous()

    # Split the reduction ~sqrt(M) ways so both stages do similar work.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[(mid_size, 1, 1)](inp, mid, M, block_size)
        prod_kernel_result[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

67 

68 

def heur_block_n(args):
    """Heuristic: cover the whole reduction axis with one power-of-2 block."""
    n = args["N"]
    return triton.next_power_of_2(n)

71 

72 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("prod"),
    key=[
        "M",
        "N",
    ],
)
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def prod_kernel(
    inp,
    out,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Reduce the middle axis of a logically (M, N, K) row-major tensor:
    # out[m, k] = prod over n of inp[m, n, k].
    # Grid: axis 0 tiles M in BLOCK_M chunks; axis 1 runs one program per k.
    # NOTE: BLOCK_N = next_power_of_2(N) via heur_block_n, so the loop
    # below executes exactly once; it is written as a loop for generality.
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Running product, accumulated in float32; 1.0 is the identity.
    acc = tl.full((BLOCK_M, BLOCK_N), value=1.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset of element (m, n, k) in the (M, N, K) layout.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k

        # set mask; out-of-range positions load the identity 1.0
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
        acc *= inp_vals
    # Collapse the BLOCK_N axis of the accumulator into one value per row.
    result_index = tl.reduce(acc, axis=1, combine_fn=reduce_mul)

    # Store out[m, k] for the rows this program owns.
    offset_index = m_offset * K + pid_k
    out_ptrs = out + offset_index
    mask1 = m_offset < M
    tl.store(out_ptrs, result_index, mask=mask1)

117 

118 

def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Product of the elements of `inp` along dimension `dim`.

    Args:
        inp: input tensor.
        dim: dimension to reduce; may be negative.  Required despite the
            None default (None raises via the assert below).
        keepdim: if True, keep the reduced dimension with size 1.
        dtype: optional output dtype; defaults to ``inp.dtype``.

    Returns:
        Tensor with dimension `dim` reduced by multiplication.
    """
    logger.debug("METAX GEMS PROD DIM")

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]  # length of the reduced axis
    M = math.prod(shape[:dim])  # combined extent of leading axes
    # Trailing extent, from the shape directly.  The previous form,
    # inp.numel() // M // N, raised ZeroDivisionError whenever a leading
    # dim or N was zero.
    K = math.prod(shape[dim + 1 :])

    # The kernel indexes flat (M, N, K) row-major memory.
    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1

    if dtype is None:
        dtype = inp.dtype
    out = torch.empty(shape_list, dtype=dtype, device=inp.device)
    if not keepdim:
        out = torch.squeeze(out, dim)

    # Nothing to compute: some non-reduced dimension is zero.
    if out.numel() == 0:
        return out
    # Reducing an empty axis: the product over no elements is 1.  Skip
    # the launch, which would otherwise get BLOCK_N == 0 from the
    # heuristic and an invalid tl.arange(0, 0).
    if N == 0:
        return out.fill_(1)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        prod_kernel[grid](inp, out, M, N, K)

    return out

147 return out