Coverage for src/flag_gems/runtime/backend/_cambricon/ops/amax.py: 0%

133 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils.shape_utils import can_use_int32_index 

11 

12from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op 

13 

14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

15 

16 

@libentry()
@triton.jit
def amax_kernel_once(
    inp,
    out,
    M: tl.constexpr,
):
    # Single-program reduction: the entire input (M elements, M known at
    # compile time) fits in one tile, so load it all, reduce, and store the
    # scalar result.
    idx = tl.arange(0, M)
    vals = tl.load(inp + idx)
    tl.store(out, tl.max(vals, 0))

28 

29 

@libentry()
@triton.autotune(configs=cfggen_reduce_op(), key=["M"])
@triton.jit
def amax_kernel_1(
    inp,  # pointer to the flattened input tensor
    out,  # pointer to a single accumulator, pre-filled by the host wrapper
    M,  # total number of elements
    BLOCK_SIZE: tl.constexpr,
    INT64_INDEX: tl.constexpr = False,  # widen offsets when int32 could overflow
):
    # Global amax over a flat buffer: each program grid-strides over the input
    # in BLOCK_SIZE chunks, keeps a running scalar maximum, and folds it into
    # `out` with a single atomic_max at the end.
    pid = tl.program_id(0)
    if INT64_INDEX:
        pid = pid.to(tl.int64)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Running maximum; -inf is the identity for max, matching the `other=`
    # fill used for masked-out lanes below.
    _tmp = -float("inf")
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=-float("inf"))
        # NOTE(review): upstream Triton's tl.max(..., return_indices=True)
        # returns (values, indices); unpacking into a 1-tuple here looks
        # specific to this Cambricon Triton fork — confirm.
        (amax_val,) = tl.max(inp_val, 0, return_indices=True)
        if amax_val > _tmp:
            # Accumulate in float32; the host wrapper allocates `out` as
            # float32 and casts back to the input dtype afterwards.
            _tmp = amax_val.to(tl.float32)
    tl.atomic_max(out, _tmp)

55 

56 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("amax_opt"), key=["N"])
@triton.jit
def amax_kernel_opt(
    inp,  # pointer to a contiguous (M, N) view; each row is reduced over N
    out,  # pointer to M output slots, pre-filled with the float32 minimum
    M: tl.constexpr,  # number of kept (non-reduced) rows
    N: tl.constexpr,  # reduction length per row
    TILE_NUM_N: tl.constexpr,  # number of axis-1 programs cooperating per row
    INT64_INDEX: tl.constexpr = False,  # widen offsets when int32 could overflow
):
    # Row-wise amax where each row is split across TILE_NUM_N programs along
    # the N axis; partial maxima are merged into the row's output slot via
    # atomic_max, which is why the host pre-fills `out` with the minimum.
    # Map the program id to the row of inp it should compute.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    if INT64_INDEX:
        pid_m = pid_m.to(tl.int64)
        pid_n = pid_n.to(tl.int64)

    # Distribute rows evenly across axis-0 programs (the last program may
    # receive fewer rows; row_end is clamped to M).
    num_jobs = tl.num_programs(0)
    rows_per_job = (M + num_jobs - 1) // num_jobs
    row_begin = pid_m * rows_per_job
    row_end = min(row_begin + rows_per_job, M)

    # Width of the slice each axis-1 program covers (ceil division so the
    # TILE_NUM_N slices together span all of N).
    BLOCK_N: tl.constexpr = (N + TILE_NUM_N - 1) // TILE_NUM_N

    for row_idx in range(row_begin, row_end):
        offset_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        inp_ptrs = inp + row_idx * N + offset_n
        mask = offset_n < N
        inps = tl.load(inp_ptrs, mask, other=-float("inf"))
        # NOTE(review): upstream Triton's tl.max(..., return_indices=True)
        # returns (values, indices); the 1-tuple unpacking appears to rely on
        # this Triton fork's behavior — confirm.
        (max_val,) = tl.max(inps, 0, return_indices=True)
        new_out = out + row_idx
        tl.atomic_max(new_out, max_val)

90 

91 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
@triton.jit
def amax_kernel(
    inp,  # pointer to a contiguous (M, N) view; each row is reduced over N
    out,  # pointer to M output slots
    M,  # number of kept (non-reduced) rows
    N,  # reduction length per row
    BLOCK_M: tl.constexpr,  # rows per tile
    BLOCK_N: tl.constexpr,  # columns per inner-loop step
    INT64_INDEX: tl.constexpr = False,  # widen offsets when int32 could overflow
):
    # Row-wise amax for long rows (host dispatches here when N > 1048576):
    # each program owns whole row tiles and sweeps the N axis itself, so the
    # final per-row maximum is written with a plain store — no atomics.
    # Map the program id to the row of inp it should compute.
    pid = tl.program_id(0)
    if INT64_INDEX:
        pid = pid.to(tl.int64)

    num_jobs = tl.num_programs(axis=0)
    start_m = pid * BLOCK_M
    step = num_jobs * BLOCK_M
    for off_m in range(start_m, M, step):
        # Column vector of row indices: shape (BLOCK_M, 1) for 2-D addressing.
        rows = off_m + tl.arange(0, BLOCK_M)[:, None]
        new_inp = inp + rows * N
        new_out = out + rows
        row_mask = rows < M

        # Running per-lane maxima, accumulated in float32; -inf is the
        # identity for max and matches the masked-load fill below.
        _all = tl.full([BLOCK_M, BLOCK_N], value=-float("inf"), dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            # NOTE(review): `and` on tensor masks relies on this Triton
            # fork's elementwise lowering — `&` is the usual spelling; confirm.
            mask = row_mask and col_mask

            a = tl.load(new_inp + cols, mask, other=-float("inf"))
            _all = tl.maximum(a, _all)

        # Reduce the accumulator across the column axis to one value per row.
        # (`all` shadows the Python builtin, but builtins are not used here.)
        all = tl.max(_all, axis=1)[:, None]
        tl.store(new_out, all, row_mask)

129 

130 

def amax(inp, dim=None, keepdim=False):
    """Return the maximum of ``inp`` over ``dim`` (``torch.amax`` semantics).

    Args:
        inp: input tensor.
        dim: ``None`` or an empty list reduces over every element; otherwise
            an int or list of ints naming the dimensions to reduce
            (negative indices allowed).
        keepdim: when True, reduced dimensions are kept with size 1.

    Returns:
        A tensor with the same dtype as ``inp`` containing the maxima.
    """
    logger.debug("GEMS_CAMBRICON AMAX")
    # Normalize an int dim up front: the original only did this in the
    # else-branch, so `len(dim)` below raised TypeError for `dim=0`.
    if isinstance(dim, int):
        dim = [dim]
    if dim is None or len(dim) == 0:
        # Full reduction over all elements.
        M = inp.numel()
        dtype = inp.dtype
        use_int64_index = not can_use_int32_index(inp)

        # Scalar output, or all-ones shape when keepdim is requested.
        out_shape = [1] * inp.dim() if keepdim else []
        if M <= 65536:
            # Small input: a single program reduces the whole tensor at once.
            out = torch.empty(out_shape, dtype=dtype, device=inp.device)
            # Fixed: was `torch.cuda.device`, which is wrong on this
            # Cambricon backend — use the device-agnostic wrapper like
            # every other launch in this file.
            with torch_device_fn.device(inp.device):
                amax_kernel_once[(1, 1, 1)](inp, out, M)
            return out
        else:
            # Large input: programs fold partial maxima into `out` via
            # atomic_max, so it must start at the lowest float32 value.
            outdtype = torch.float32
            out = torch.full(
                out_shape, torch.finfo(outdtype).min, dtype=outdtype, device=inp.device
            )
            grid = lambda meta: (
                min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            with torch_device_fn.device(inp.device):
                amax_kernel_1[grid](inp, out, M, INT64_INDEX=use_int64_index)
            return out.to(dtype)
    else:
        # Fixed: the original asserted a bare generator expression, which is
        # always truthy — the bounds check never actually ran.
        assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"
        dtype = inp.dtype

        shape = list(inp.shape)
        dim = [d % inp.ndim for d in dim]
        # Move the reduced dims to the innermost positions so each output
        # element corresponds to a contiguous run of N input elements.
        inp = dim_compress(inp, dim)
        use_int64_index = not can_use_int32_index(inp)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N

        with torch_device_fn.device(inp.device):
            if N > 1048576:
                # Very long rows: each program sweeps whole rows itself and
                # writes results with plain stores (no atomics needed).
                out = torch.empty(shape, dtype=dtype, device=inp.device)
                grid = lambda meta: (
                    min(triton.cdiv(M, meta["BLOCK_M"]), TOTAL_CORE_NUM),
                )
                amax_kernel[grid](inp, out, M, N, INT64_INDEX=use_int64_index)
            else:
                # Shorter rows: TILE_NUM_N programs cooperate on each row and
                # merge partial maxima via atomic_max, so `out` is pre-filled
                # with the float32 minimum.
                out = torch.full(
                    shape,
                    torch.finfo(torch.float32).min,
                    dtype=torch.float32,
                    device=inp.device,
                )
                grid = lambda meta: (
                    min(triton.cdiv(TOTAL_CORE_NUM, meta["TILE_NUM_N"]), M),
                    meta["TILE_NUM_N"],
                )
                amax_kernel_opt[grid](inp, out, M, N, INT64_INDEX=use_int64_index)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out.to(dtype)