Coverage for src/flag_gems/runtime/backend/_mthreads/ops/amax.py: 0%

143 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import builtins 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.ops.amax import amax as base_amax 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13from flag_gems.utils.limits import get_dtype_min 

14 

# Module-scoped logger; the last component of __name__ ("amax") keys the
# logger under the mthreads backend namespace.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.split(".")[-1]
)

18 

# Candidate tile shapes for the autotuned row-reduction kernel (amax_kernel):
# BLOCK_M = rows reduced per program, BLOCK_N = columns loaded per loop step.
# _prune_reduction_configs filters this list by the actual reduction length N.
AMAX_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
]

31 

32 

33def _prune_reduction_configs(configs, nargs, **meta): 

34 n = meta.get("N", nargs["N"]) 

35 if n <= 128: 

36 max_block_n = 128 

37 elif n <= 2048: 

38 max_block_n = 256 

39 elif n <= 8192: 

40 max_block_n = 512 

41 else: 

42 max_block_n = 1024 

43 return [cfg for cfg in configs if cfg.kwargs["BLOCK_N"] <= max_block_n] 

44 

45 

46def _flatten_dim(shape, dim): 

47 dim = dim % len(shape) 

48 n = shape[dim] 

49 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1 

50 outer = math.prod(shape[:dim]) if dim > 0 else 1 

51 return dim, n, inner, outer 

52 

53 

# Stage 1 of the two-stage global amax: each program reduces one BLOCK_SIZE
# chunk of the flattened input and writes its partial max to mid[pid].
@libentry()
@triton.jit
def amax_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < M
    # Masked-off lanes load the dtype's minimum so they can never win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + offset, mask=mask, other=min_value, cache_modifier=".cg")
    tl.store(mid + pid, tl.max(vals))

68 

69 

# Stage 2 of the two-stage global amax: a single program reduces all partial
# maxima in `mid` (mid_size entries, BLOCK_MID >= mid_size) to one scalar.
@libentry()
@triton.jit
def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    mask = offset < mid_size
    # Padding lanes load the dtype's minimum so they are neutral for max.
    min_value = get_dtype_min(mid.type.element_ty)
    vals = tl.load(mid + offset, mask=mask, other=min_value)
    tl.store(out, tl.max(vals))

78 

79 

# One-program-per-row amax for short reductions (caller dispatches here only
# when inner == 1 and N <= 128): each program loads its whole row in a single
# BLOCK_N-wide vector and reduces it.
@libentry()
@triton.jit
def amax_kernel_small(
    inp,
    out_value,
    M,
    N,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_N: tl.constexpr,
):
    row = tle.program_id(0)
    row_mask = row < M
    cols = tl.arange(0, BLOCK_N)
    col_mask = cols < N

    # Widen strides to int64 before the multiply to avoid 32-bit offset
    # overflow on large tensors.
    stride_outer = tl.full((), STRIDE_OUTER, tl.int64)
    stride_reduce = tl.full((), STRIDE_REDUCE, tl.int64)
    offsets = row.to(tl.int64) * stride_outer + cols.to(tl.int64) * stride_reduce

    # Accumulate half-precision inputs in fp32; other dtypes reduce as-is.
    dtype = inp.type.element_ty
    acc_type = tl.float32 if (dtype is tl.float16 or dtype is tl.bfloat16) else dtype
    min_value = get_dtype_min(dtype)
    vals = tl.load(inp + offsets, mask=row_mask & col_mask, other=min_value).to(
        acc_type
    )
    row_max = tl.max(vals, axis=0)
    tl.store(out_value + row, row_max, mask=row_mask)

108 

109 

# Autotuned tiled amax along one dimension. Each program owns BLOCK_M output
# rows (a row is one (outer, inner) pair of the flattened non-reduced dims)
# and sweeps the reduction dimension in BLOCK_N-wide steps.
@libentry()
@triton.autotune(
    configs=AMAX_REDUCTION_CONFIGS,
    key=["M", "N"],
    warmup=8,
    rep=40,
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
)
@triton.jit
def amax_kernel(
    inp,
    out_value,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # int64 indexing so offsets cannot overflow on large tensors.
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Split the flat row index back into (outer, inner) coordinates; the
    # element address is outer * STRIDE_OUTER + k * STRIDE_REDUCE + inner.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    # Accumulate half-precision inputs in fp32; other dtypes reduce as-is.
    dtype = inp.type.element_ty
    acc_type = tl.float32 if (dtype is tl.float16 or dtype is tl.bfloat16) else dtype
    min_value = get_dtype_min(dtype)
    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)

    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        n_offset = n_offset.to(tl.int64)
        mask = row_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = base_ptr[:, None] + n_offset[None, :] * STRIDE_REDUCE
        # Masked lanes read the dtype minimum so they never affect the max.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value, cache_modifier=".cg")
        inp_vals = inp_vals.to(acc_type)
        local_max = tl.max(inp_vals, axis=1)
        max_values = tl.maximum(max_values, local_max)

    out_value_ptrs = out_value + rows
    tl.store(out_value_ptrs, max_values, mask=row_mask)

156 

157 

def _amax_global(inp, keepdim):
    """Two-stage global max over a contiguous, non-empty tensor.

    Stage 1 (amax_kernel_1) reduces BLOCK-sized chunks into `mid`; stage 2
    (amax_kernel_2) reduces `mid` to a scalar. Returns a 0-d tensor, or an
    all-ones-shaped tensor when ``keepdim`` is true.
    """
    M = inp.numel()
    # Chunk size near 4*sqrt(M), capped at 4096 and at the padded input size,
    # balancing stage-1 parallelism against stage-2 reduction width.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    block_size = builtins.min(block_size * 4, 4096, triton.next_power_of_2(M))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)

    if keepdim:
        out = torch.empty([1] * inp.dim(), dtype=dtype, device=inp.device)
    else:
        out = torch.empty([], dtype=dtype, device=inp.device)

    # Roughly one warp per 128 lanes, clamped to [1, 8].
    num_warps_block = builtins.min(8, builtins.max(1, block_size // 128))
    num_warps_mid = builtins.min(8, builtins.max(1, block_mid // 128))

    with torch_device_fn.device(inp.device):
        amax_kernel_1[(mid_size, 1, 1)](
            inp, mid, M, block_size, num_warps=num_warps_block, num_stages=2
        )
        amax_kernel_2[(1, 1, 1)](
            mid, out, mid_size, block_mid, num_warps=num_warps_mid, num_stages=2
        )
    return out


def _amax_along_dim(inp, dim_val, keepdim):
    """Max-reduce a contiguous tensor along the single dimension ``dim_val``.

    Flattens the shape to (outer, N, inner); each of the M = outer * inner
    output rows is reduced over N elements spaced STRIDE_REDUCE apart.
    """
    shape = list(inp.shape)
    dim_val, N, inner, outer = _flatten_dim(shape, dim_val)
    M = outer * inner
    stride_reduce = inp.stride()[dim_val]
    # Distance between consecutive outer slices of the reduced dim; for a
    # contiguous tensor this equals inner * N.
    stride_outer = stride_reduce * N

    out_value = torch.empty((M,), dtype=inp.dtype, device=inp.device)

    if inner == 1 and N <= 128:
        # Short rows: one tiny program per row, whole row in one load.
        block_n = builtins.min(triton.next_power_of_2(N), 128)
        with torch_device_fn.device(inp.device):
            amax_kernel_small[(M,)](
                inp,
                out_value,
                M,
                N,
                stride_outer,
                stride_reduce,
                block_n,
                num_warps=1,
                num_stages=1,
            )
    else:
        # General case: autotuned tiled kernel, BLOCK_M rows per program.
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            amax_kernel[grid](
                inp,
                out_value,
                M,
                N,
                builtins.max(inner, 1),
                stride_outer,
                stride_reduce,
            )

    out_shape = shape.copy()
    out_shape[dim_val] = 1
    out_value = out_value.view(out_shape)
    if not keepdim:
        out_value = torch.squeeze(out_value, dim_val)
    return out_value


def amax(inp, dim=None, keepdim=False):
    """Maximum of ``inp`` over ``dim`` (torch.amax semantics).

    Args:
        inp: input tensor.
        dim: None or an empty list/tuple for a global reduction, an int or a
            single-element list/tuple for a one-dimensional reduction.
            Multi-dim reductions fall back to the base implementation.
        keepdim: keep reduced dimensions with size 1.

    Returns:
        Tensor holding the maxima; shape follows torch.amax.

    Raises:
        AssertionError: if a single ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS_MTHREADS AMAX")

    if dim is None or (isinstance(dim, (list, tuple)) and len(dim) == 0):
        # Global reduction path.
        if not inp.is_contiguous():
            inp = inp.contiguous()
        if inp.numel() == 0:
            # Empty input: defer to the base implementation's error handling.
            return base_amax(inp, dim=dim, keepdim=keepdim)
        return _amax_global(inp, keepdim)

    if isinstance(dim, int):
        dim = [dim]

    # Multi-dim reduction: use the base implementation.
    if len(dim) > 1:
        return base_amax(inp, dim=dim, keepdim=keepdim)

    dim_val = dim[0]
    assert dim_val >= -inp.ndim and dim_val < inp.ndim, "Invalid dim"
    dim_val = dim_val % inp.ndim

    # Non-contiguous layouts: use the base implementation.
    if not inp.is_contiguous():
        return base_amax(inp, dim=dim, keepdim=keepdim)

    return _amax_along_dim(inp, dim_val, keepdim)

254 

255 

# Public API of this module.
__all__ = ["amax"]