# src/flag_gems/runtime/backend/_cambricon/ops/prod.py

import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op2, count_divisible_by_2

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@triton.jit
def reduce_mul(a, b):
    # Combine function: folds two partial products into one.
    return a * b
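
# tl.reduce applies an associative combine function pairwise along an axis,
# e.g. tl.reduce(v, axis=0, combine_fn=reduce_mul) gives the product of v.
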

@libentry()
@triton.autotune(configs=cfggen_reduce_op2(), key=["M"])
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
    ITER_NUM: tl.constexpr,
):
    # Stage 1: each program folds its grid-strided slice of the flattened
    # input into a BLOCK_SIZE-wide vector of partial products.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    _tmp = tl.full([BLOCK_SIZE], value=1.0, dtype=tl.float32)
    block_start = block_start.to(tl.int64)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=1.0).to(tl.float32)
        _tmp = inp_val * _tmp

    # Manual in-place tree reduction: each step multiplies the upper half of
    # the active slice into the lower half; with BLOCK_SIZE a power of two,
    # ITER_NUM - 1 halvings leave the block's product in _tmp[0].
    # TODO: switch back to a plain tl.reduce once it is optimized.
    for x in tl.static_range(1, int(ITER_NUM), 1):
        _tmp[: BLOCK_SIZE // (2**x)] = (
            _tmp[: BLOCK_SIZE // (2**x)]
            * _tmp[BLOCK_SIZE // (2**x) : (BLOCK_SIZE // (2**x)) * 2]
        )

    mid_ptr = mid + pid
    tl.store(mid_ptr, _tmp[0])


@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size: tl.constexpr, loop_num: tl.constexpr):
    # Stage 2: a single program multiplies the per-core partial products down
    # to one scalar.
    offset = tl.arange(0, mid_size)
    mid_val = tl.load(mid + offset)

    # Same manual tree reduction as above; after loop_num - 1 halvings the
    # active slice has odd length, and tl.reduce finishes the tail.
    # TODO: switch back to a plain tl.reduce once it is optimized.
    for x in tl.static_range(1, loop_num, 1):
        mid_val[: mid_size // (2**x)] = (
            mid_val[: mid_size // (2**x)]
            * mid_val[mid_size // (2**x) : (mid_size // (2**x)) * 2]
        )

    prod_val = tl.reduce(
        mid_val[: mid_size // (2 ** (loop_num - 1))], axis=0, combine_fn=reduce_mul
    )
    tl.store(out, prod_val)
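
# The halving schedule used by both kernels is easier to read off-device.
# A minimal Python sketch of the same scheme (hypothetical helper, not part
# of this module): halve the vector while its length is even, multiplying the
# upper half into the lower half, then reduce the odd-length tail serially.
#
#     def tree_prod(vals):
#         n = len(vals)
#         while n > 1 and n % 2 == 0:
#             half = n // 2
#             vals = [vals[i] * vals[i + half] for i in range(half)]
#             n = half
#         result = 1.0
#         for v in vals:  # mirrors the final tl.reduce in prod_kernel_result
#             result *= v
#         return result
#
#     assert tree_prod([1.0, 2.0, 3.0, 4.0]) == 24.0
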

def prod(inp, *, dtype=None):
    # Full reduction: product of all elements, matching torch.prod(inp).
    logger.debug("GEMS_CAMBRICON PROD")
    if dtype is None:
        dtype = inp.dtype

    M = inp.numel()
    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    mid_size = TOTAL_CORE_NUM
    # Halving steps (+1) for the result kernel: mid_size is halved while even,
    # then the odd-length remainder is reduced by tl.reduce.
    loop_num = count_divisible_by_2(mid_size) + 1

    # mid is initialized to ones so slots of any unlaunched programs stay
    # neutral under multiplication.
    mid = torch.ones((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[grid](inp, mid, M)
        prod_kernel_result[(1, 1, 1)](mid, out, mid_size, loop_num)
    return out
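
# A minimal usage sketch (assumes flag_gems is importable and a Cambricon MLU
# device is present; the "mlu" device string is an assumption):
#
#     x = torch.randn(4096, device="mlu")
#     torch.testing.assert_close(prod(x), torch.prod(x))
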

def heur_block_n(args):
    # Heuristic: one power-of-two block spanning the whole reduction axis.
    return triton.next_power_of_2(args["N"])


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("prod"),
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def prod_kernel(
    inp,
    out,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program reduces BLOCK_M rows across the N axis for one k-slice.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    acc = tl.full((BLOCK_M, BLOCK_N), value=1.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k

        # Mask out-of-range rows and columns; 1.0 is the multiplicative
        # identity, so masked lanes do not affect the product.
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
        acc *= inp_vals
    result = tl.reduce(acc, axis=1, combine_fn=reduce_mul)

    offset_index = m_offset * K + pid_k
    out_ptrs = out + offset_index
    row_mask = m_offset < M
    tl.store(out_ptrs, result, mask=row_mask)
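
# prod_dim flattens the tensor around the reduced axis: M collapses the dims
# before `dim`, N is the reduced extent, K collapses the dims after it, so a
# contiguous element (m, n, k) sits at offset m * N * K + n * K + k. For a
# (2, 3, 4) tensor reduced over dim=1: M=2, N=3, K=4, and the output has
# shape (2, 1, 4) (or (2, 4) with keepdim=False).
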

def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    # Product over a single axis, matching torch.prod(inp, dim, keepdim).
    # Note: dim must be an int here; a full reduction goes through prod above.
    logger.debug("GEMS_CAMBRICON PROD DIM")

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1

    if dtype is None:
        dtype = inp.dtype
    out = torch.empty(shape_list, dtype=dtype, device=inp.device)
    if not keepdim:
        out = torch.squeeze(out, dim)

    # One program per (row-block, k-slice) pair.
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        prod_kernel[grid](inp, out, M, N, K)

    return out
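
# A minimal usage sketch (same device assumption as for prod above):
#
#     x = torch.randn(2, 3, 4, device="mlu")
#     y = prod_dim(x, dim=1)  # shape (2, 4)
#     torch.testing.assert_close(y, torch.prod(x, dim=1))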