Coverage for src/flag_gems/runtime/backend/_mthreads/ops/max.py: 0%

150 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import builtins 

2import logging 

3import math 

4from collections import namedtuple 

5 

6import torch 

7import triton 

8import triton.language as tl 

9 

10from flag_gems.ops.max import max as base_max 

11from flag_gems.ops.max import max_dim as base_max_dim 

12from flag_gems.runtime import torch_device_fn 

13from flag_gems.utils import libentry 

14from flag_gems.utils import triton_lang_extension as tle 

15from flag_gems.utils.limits import get_dtype_min 

16 

# Module logger namespaced under the mthreads backend; the suffix is this
# module's own basename (here: "max").
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Result pair returned by max_dim, mirroring the (values, indices) tuple
# shape of torch.max(input, dim).
MaxOut = namedtuple("max", ["values", "indices"])

22 

# Autotune search space for max_kernel: BLOCK_M = output rows per program,
# BLOCK_N = reduction-axis tile width.  Oversized BLOCK_N entries are pruned
# at launch time by _prune_reduction_configs based on the actual N.
MAX_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
]

35 

36 

37def _prune_reduction_configs(configs, nargs, **meta): 

38 n = meta.get("N", nargs["N"]) 

39 if n <= 128: 

40 max_block_n = 128 

41 elif n <= 2048: 

42 max_block_n = 256 

43 elif n <= 8192: 

44 max_block_n = 512 

45 else: 

46 max_block_n = 1024 

47 return [cfg for cfg in configs if cfg.kwargs["BLOCK_N"] <= max_block_n] 

48 

49 

50def _flatten_dim(shape, dim): 

51 dim = dim % len(shape) 

52 n = shape[dim] 

53 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1 

54 outer = math.prod(shape[:dim]) if dim > 0 else 1 

55 return dim, n, inner, outer 

56 

57 

@libentry()
@triton.jit
def max_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the two-stage global max: each program reduces one
    # BLOCK_SIZE-wide slice of the flattened input into mid[pid].
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < M
    # Out-of-range lanes load the dtype's minimum so they can never win.
    min_value = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + offset, mask=mask, other=min_value, cache_modifier=".cg")
    tl.store(mid + pid, tl.max(vals))

72 

73 

@libentry()
@triton.jit
def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2: a single program reduces the per-block partial maxima in `mid`
    # to the final scalar.  BLOCK_MID must be >= mid_size (the host passes
    # next_power_of_2(mid_size)).
    offset = tl.arange(0, BLOCK_MID)
    mask = offset < mid_size
    # Masked lanes load the dtype minimum so they do not affect the result.
    min_value = get_dtype_min(mid.type.element_ty)
    vals = tl.load(mid + offset, mask=mask, other=min_value)
    tl.store(out, tl.max(vals))

82 

83 

@libentry()
@triton.jit
def max_kernel_small(
    inp,
    out_value,
    out_index,
    M,
    N,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_N: tl.constexpr,
):
    # Fast path for short reductions: one program per output row, with the
    # whole reduction axis (N <= BLOCK_N) handled in a single pass.
    row = tle.program_id(0)
    row_mask = row < M
    cols = tl.arange(0, BLOCK_N)
    col_mask = cols < N

    # Widen strides to int64 so the address arithmetic cannot overflow on
    # large tensors.
    stride_outer = tl.full((), STRIDE_OUTER, tl.int64)
    stride_reduce = tl.full((), STRIDE_REDUCE, tl.int64)
    offsets = row.to(tl.int64) * stride_outer + cols.to(tl.int64) * stride_reduce

    dtype = inp.type.element_ty
    # Accumulate half-precision inputs in float32 for accuracy.
    acc_type = tl.float32 if (dtype is tl.float16 or dtype is tl.bfloat16) else dtype
    min_value = get_dtype_min(dtype)
    vals = tl.load(inp + offsets, mask=row_mask & col_mask, other=min_value).to(
        acc_type
    )
    # Left tie-break: on equal values the lowest index wins.
    row_max, row_argmax = tl.max(
        vals,
        axis=0,
        return_indices=True,
        return_indices_tie_break_left=True,
    )
    tl.store(out_value + row, row_max, mask=row_mask)
    # Indices are stored as int32; the host widens them to int64 afterwards.
    tl.store(out_index + row, row_argmax.to(tl.int32), mask=row_mask)

119 

120 

@libentry()
@triton.autotune(
    configs=MAX_REDUCTION_CONFIGS,
    key=["M", "N"],
    warmup=8,
    rep=40,
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
)
@triton.jit
def max_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # General reduce-along-dim kernel: each program owns BLOCK_M output rows
    # (one row = one (outer, inner) coordinate pair) and walks the reduction
    # axis in BLOCK_N-wide tiles, tracking a running max and its argmax.
    pid_m = tle.program_id(0)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Decompose the flat row id into (outer, inner) coordinates.  inner_idx
    # is added with an implicit stride of 1, i.e. the kernel assumes unit
    # stride along the inner dimensions (the host only launches it on
    # contiguous inputs).
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    dtype = inp.type.element_ty
    # Accumulate half-precision inputs in float32 for accuracy.
    acc_type = tl.float32 if (dtype is tl.float16 or dtype is tl.bfloat16) else dtype
    min_value = get_dtype_min(dtype)
    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
    argmax_values = tl.full([BLOCK_M], dtype=tl.int32, value=0)

    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        n_offset = n_offset.to(tl.int64)
        mask = row_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = base_ptr[:, None] + n_offset[None, :] * STRIDE_REDUCE
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value, cache_modifier=".cg")
        inp_vals = inp_vals.to(acc_type)
        # Per-tile max with left tie-break (lowest index wins inside a tile).
        local_max, local_argmax = tl.max(
            inp_vals,
            axis=1,
            return_indices=True,
            return_indices_tie_break_left=True,
        )
        local_argmax = local_argmax.to(tl.int32)
        # Strict '>' keeps the earlier tile's index on ties across tiles, so
        # the overall argmax is the first occurrence of the maximum.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(
            update, (start_n + local_argmax).to(tl.int32), argmax_values
        )

    out_value_ptrs = out_value + rows
    out_index_ptrs = out_index + rows
    tl.store(out_value_ptrs, max_values, mask=row_mask)
    tl.store(out_index_ptrs, argmax_values, mask=row_mask)

181 

182 

def max(inp):
    """Global maximum over all elements of *inp*, via a two-stage reduction."""
    logger.debug("GEMS_MTHREADS MAX")
    if not inp.is_contiguous():
        inp = inp.contiguous()
    if inp.numel() == 0:
        # Defer to the reference implementation for the empty-tensor path.
        return base_max(inp)

    total = inp.numel()
    # First-stage block size: roughly 4 * sqrt(total), rounded to a power of
    # two and clamped to 4096 and to the smallest power of two covering the
    # whole tensor.
    blk = triton.next_power_of_2(math.ceil(math.sqrt(total)))
    blk = builtins.min(blk * 4, 4096, triton.next_power_of_2(total))
    n_partial = triton.cdiv(total, blk)
    blk_mid = triton.next_power_of_2(n_partial)

    partial = torch.empty((n_partial,), dtype=inp.dtype, device=inp.device)
    result = torch.empty([], dtype=inp.dtype, device=inp.device)

    # One warp per 128 lanes, clamped to [1, 8].  builtins.max is used
    # because this function shadows the builtin.
    warps_stage1 = builtins.min(8, builtins.max(1, blk // 128))
    warps_stage2 = builtins.min(8, builtins.max(1, blk_mid // 128))

    with torch_device_fn.device(inp.device):
        max_kernel_1[(n_partial, 1, 1)](
            inp, partial, total, blk, num_warps=warps_stage1, num_stages=2
        )
        max_kernel_2[(1, 1, 1)](
            partial, result, n_partial, blk_mid, num_warps=warps_stage2, num_stages=2
        )
    return result

211 

212 

def max_dim(inp, dim=None, keepdim=False):
    """Max values and int64 argmax indices of *inp* along *dim*.

    Returns a ``MaxOut(values, indices)`` named tuple shaped like
    ``torch.max(input, dim, keepdim)``.
    """
    logger.debug("GEMS_MTHREADS MAX DIM")
    assert dim is not None, "dim must be specified"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim

    # Non-contiguous layouts fall back to the generic implementation.
    if not inp.is_contiguous():
        return base_max_dim(inp, dim=dim, keepdim=keepdim)

    shape = list(inp.shape)
    dim, N, inner, outer = _flatten_dim(shape, dim)
    M = outer * inner
    strides = inp.stride()
    step_reduce = strides[dim]
    step_outer = step_reduce * N

    # Flat per-row outputs; reshaped back to the input's layout below.
    values_flat = torch.empty((M,), dtype=inp.dtype, device=inp.device)
    indices_flat = torch.empty((M,), dtype=torch.int32, device=inp.device)

    with torch_device_fn.device(inp.device):
        if inner == 1 and N <= 128:
            # Short innermost reductions: one program per row, single pass.
            block_n = builtins.min(triton.next_power_of_2(N), 128)
            max_kernel_small[(triton.cdiv(M, 1),)](
                inp,
                values_flat,
                indices_flat,
                M,
                N,
                step_outer,
                step_reduce,
                block_n,
                num_warps=1,
                num_stages=1,
            )
        else:
            # Autotuned tiled kernel; grid depends on the chosen BLOCK_M.
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            max_kernel[grid](
                inp,
                values_flat,
                indices_flat,
                M,
                N,
                builtins.max(inner, 1),
                step_outer,
                step_reduce,
            )

    reduced_shape = list(shape)
    reduced_shape[dim] = 1
    values = values_flat.view(reduced_shape)
    # Kernels emit int32 indices; widen to int64 to match torch.max.
    indices = indices_flat.view(reduced_shape).to(torch.int64)
    if not keepdim:
        values = torch.squeeze(values, dim)
        indices = torch.squeeze(indices, dim)

    return MaxOut(values=values, indices=indices)

271 

272 

# Public API of this backend module.
__all__ = ["max", "max_dim"]