Coverage for src/flag_gems/runtime/backend/_mthreads/ops/prod.py: 0%

188 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2import math 

3from typing import Sequence 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Per-module logger, namespaced under the mthreads backend ops package so log
# filtering can target this op file individually.
logger = logging.getLogger(
    f"flag_gems.runtime.backend._mthreads.ops.{__name__.rsplit('.', 1)[-1]}"
)

16 

17 

@triton.jit
def reduce_mul(a, b):
    """Binary combine function for ``tl.reduce``: returns the product of two partials."""
    return a * b

21 

22 

# Autotune candidates for the generic per-dim product kernel.
# BLOCK_M = output elements (rows) handled per program; BLOCK_N = chunk width
# along the reduced axis. Larger tiles are paired with more warps/stages.
# The list is filtered per launch shape by _prune_reduction_configs.
NAIVE_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 8, "BLOCK_N": 64}, num_warps=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
]

35 

36 

37def _prune_reduction_configs(configs, named_args, **meta): 

38 """Skip oversized tiles to avoid needless autotune on tiny shapes.""" 

39 M = named_args["M"] 

40 N = named_args["N"] 

41 max_block_m = max(M, 8) 

42 min_block_m = 8 

43 n_cap = 1 << (N - 1).bit_length() 

44 n_cap = max(64, min(n_cap, 1024)) 

45 filtered = [ 

46 cfg 

47 for cfg in configs 

48 if min_block_m <= cfg.kwargs["BLOCK_M"] <= max_block_m 

49 and cfg.kwargs["BLOCK_N"] <= max(256, n_cap) 

50 ] 

51 return filtered or configs 

52 

53 

54def _flatten_dim(shape: Sequence[int], dim: int): 

55 dim = dim % len(shape) 

56 n = shape[dim] 

57 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1 

58 outer = math.prod(shape[:dim]) if dim > 0 else 1 

59 return dim, n, inner, outer 

60 

61 

62def _reshape_output(out: torch.Tensor, shape: list[int], dim: int, keepdim: bool): 

63 out_shape = shape.copy() 

64 out_shape[dim] = 1 

65 out_view = out.view(out_shape) 

66 if not keepdim: 

67 out_view = torch.squeeze(out_view, dim) 

68 return out_view 

69 

70 

@libentry()
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flat product reduction.

    Each program multiplies one BLOCK_SIZE chunk of the flattened input
    (M total elements) and writes its partial product to ``mid[pid]``.
    """
    dtype = inp.type.element_ty
    # Accumulate in fp32 to limit rounding error for low-precision inputs.
    acc_dtype = tl.float32
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # other=1.0 is the multiplicative identity, so masked tail lanes are no-ops.
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(acc_dtype)
    mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul).to(dtype)
    mid_ptr = mid + pid
    tl.store(mid_ptr, mid_value)

89 

90 

@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the flat product reduction: a single program multiplies all
    ``mid_size`` partial products in ``mid`` into the scalar ``out``.

    BLOCK_MID must be >= mid_size (caller passes next_power_of_2(mid_size)).
    """
    dtype = mid.type.element_ty
    # fp32 accumulation, consistent with prod_kernel_mid.
    acc_dtype = tl.float32
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # other=1.0 keeps out-of-range lanes neutral under multiplication.
    mid_val = tl.load(mid_ptrs, mask=mask, other=1.0).to(acc_dtype)
    prod_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_mul).to(dtype)
    tl.store(out, prod_val)

102 

103 

@triton.jit
def prod_kernel_dim_64(
    inp,
    out,
    M,
    INNER,
    STRIDE_OUTER,
    BLOCK_M: tl.constexpr,
):
    """Fast path: product over a contiguous reduced axis of exactly 64 elements.

    Each program reduces BLOCK_M rows; ``rows * STRIDE_OUTER`` is the start of
    each row and the 64 columns are loaded as one unmasked-width tile.
    INNER is accepted for signature parity with the sibling kernels but unused.
    """
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = rows < M
    base_ptr = inp + rows * STRIDE_OUTER
    cols = tl.arange(0, 64)
    # Bug fix: the grid is cdiv(M, BLOCK_M) programs, so when M % BLOCK_M != 0
    # the last program covers rows >= M. The original load had no mask and read
    # past the end of `inp`; mask those rows and substitute the multiplicative
    # identity so the (discarded) tail products are harmless.
    vals = tl.load(
        base_ptr[:, None] + cols[None, :],
        mask=row_mask[:, None],
        other=1.0,
        cache_modifier=".cg",
    )
    prod_vals = tl.reduce(vals.to(tl.float32), axis=1, combine_fn=reduce_mul)
    tl.store(out + rows, prod_vals.to(inp.type.element_ty), mask=row_mask)

122 

@triton.jit
def prod_kernel_dim_contig(
    inp,
    out,
    M,
    INNER,
    STRIDE_OUTER,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Product over a contiguous reduced axis that fits in a single BLOCK_N tile.

    The caller (prod_dim) launches this only when the reduced-axis stride is 1
    and STRIDE_OUTER equals the reduced length N, so each row is one dense run
    of N elements.  INNER is accepted for signature parity but unused.
    """
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = rows < M
    base_ptr = inp + rows * STRIDE_OUTER
    cols = tl.arange(0, BLOCK_N)
    # On this path STRIDE_OUTER == N, so this masks the tail columns when
    # BLOCK_N was rounded up past N.
    col_mask = cols[None, :] < STRIDE_OUTER
    mask = row_mask[:, None] & col_mask
    vals = tl.load(
        base_ptr[:, None] + cols[None, :],
        mask=mask,
        other=1.0,  # multiplicative identity for masked lanes
        cache_modifier=".cg",
    )
    # Reduce each row's tile in fp32, then cast back to the input element type.
    prod_vals = tl.reduce(vals.to(tl.float32), axis=1, combine_fn=reduce_mul)
    tl.store(out + rows, prod_vals.to(inp.type.element_ty), mask=row_mask)

148 

149 

@triton.jit
def prod_kernel_dim_dense(
    inp,
    out,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Unmasked per-dim product for exactly-tiled shapes.

    prod_dim launches this only when M % BLOCK_M == 0 and N % BLOCK_N == 0,
    so every load and store is in bounds without masks — that is the point
    of this variant.
    """
    dtype = inp.type.element_ty
    # fp32 running product limits rounding error for fp16/bf16 inputs.
    acc_dtype = tl.float32
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    # Decompose the flattened output index into coordinates over the
    # non-reduced axes: outer (before dim) and inner (after dim).
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    acc = tl.full((BLOCK_M,), value=1.0, dtype=acc_dtype)
    # Walk the reduced axis in BLOCK_N chunks, folding each tile's product
    # into the per-row accumulator.
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        vals = tl.load(
            base_ptr[:, None] + cols[None, :] * STRIDE_REDUCE,
            cache_modifier=".cg",
        ).to(acc_dtype)
        chunk_prod = tl.reduce(vals, axis=1, combine_fn=reduce_mul)
        acc *= chunk_prod

    tl.store(out + rows, acc.to(dtype))

181 

182 

@triton.autotune(
    configs=NAIVE_REDUCTION_CONFIGS,
    key=["M", "N"],
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
    warmup=2,
    rep=8,
)
@triton.jit
def prod_kernel_dim(
    inp,
    out,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Generic, fully masked per-dim product; the autotuned fallback path.

    Each program computes BLOCK_M flattened output elements, consuming the
    reduced axis (length N) in BLOCK_N chunks with an fp32 running product.
    Works for any strides/shapes, unlike the dense/contig fast paths.
    """
    dtype = inp.type.element_ty
    acc_dtype = tl.float32
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    # int64 indexing avoids overflow of the address arithmetic for very
    # large M * N problems.
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Map flattened output index -> (outer, inner) coordinates around the
    # reduced axis.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    acc = tl.full((BLOCK_M,), value=1.0, dtype=acc_dtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        cols = cols.to(tl.int64)
        col_mask = cols < N
        mask = row_mask[:, None] & col_mask[None, :]
        # other=1.0: masked lanes contribute the multiplicative identity.
        vals = tl.load(
            base_ptr[:, None] + cols[None, :] * STRIDE_REDUCE,
            mask=mask,
            other=1.0,
            cache_modifier=".cg",
        ).to(acc_dtype)
        chunk_prod = tl.reduce(vals, axis=1, combine_fn=reduce_mul)
        acc *= chunk_prod

    out_ptrs = out + rows
    tl.store(out_ptrs, acc.to(dtype), mask=row_mask)

230 

231 

def prod(inp, *, dtype=None):
    """Product of every element of *inp* via a two-stage reduction.

    Stage 1 (prod_kernel_mid) writes one partial product per block; stage 2
    (prod_kernel_result) multiplies the partials into a 0-d tensor.
    """
    logger.debug("GEMS_MTHREADS PROD")
    if dtype is None:
        dtype = inp.dtype
    if not inp.is_contiguous():
        inp = inp.contiguous()

    numel = inp.numel()
    # A block size near sqrt(numel) (doubled, capped at 4096 and at the
    # whole problem) balances work between the two stages.
    block = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    block = min(block * 2, 4096, triton.next_power_of_2(numel))
    n_partials = triton.cdiv(numel, block)
    block_mid = triton.next_power_of_2(n_partials)

    partials = torch.empty((n_partials,), dtype=dtype, device=inp.device)
    result = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[(n_partials, 1, 1)](inp, partials, numel, block)
        prod_kernel_result[(1, 1, 1)](partials, result, n_partials, block_mid)
    return result

252 

253 

def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Product of *inp* along a single dimension *dim*.

    Dispatches, in order, to:
      1. prod_kernel_dim_64     — contiguous reduced axis of exactly 64;
      2. prod_kernel_dim_dense  — a cached/heuristic config whose tiles divide
                                  (m, n) exactly, so the kernel runs unmasked;
      3. prod_kernel_dim_contig — contiguous reduced axis of length <= 1024;
      4. prod_kernel_dim        — fully masked autotuned fallback.
    Returns a tensor shaped like *inp* with dim reduced (kept as size 1 when
    *keepdim* is True).
    """
    logger.debug("GEMS_MTHREADS PROD DIM")
    assert dim is not None, "dim must be specified"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim

    if dtype is None:
        dtype = inp.dtype
    if not inp.is_contiguous():
        # dim_compress moves the reduced axis last and compacts the layout,
        # so from here on the reduced axis is the final one.
        inp = dim_compress(inp, dim)
        dim = inp.ndim - 1

    shape = list(inp.shape)
    # n = reduced extent; m = number of output elements (outer * inner).
    dim, n, inner, outer = _flatten_dim(shape, dim)
    m = outer * inner

    out_flat = torch.empty((m,), dtype=dtype, device=inp.device)

    stride = inp.stride()
    stride_reduce = stride[dim]
    # Row stride of the flattened (outer, n, inner) view; for a contiguous
    # reduced axis (stride_reduce == 1) this equals n.
    stride_outer = stride_reduce * n

    # Fast path 1: contiguous reduced axis of exactly 64 elements.
    if n == 64 and stride_reduce == 1 and stride_outer == n:
        grid_64 = (triton.cdiv(m, 8),)
        with torch_device_fn.device(inp.device):
            prod_kernel_dim_64[grid_64](
                inp, out_flat, m, inner, stride_outer, BLOCK_M=8, num_warps=2
            )
        return _reshape_output(out_flat, shape, dim, keepdim)

    # NOTE(review): `.cache` is presumably the config cache exposed by the
    # libentry/autotune wrapper around prod_kernel_dim — confirm its contract.
    key = (m, n, str(dtype), str(out_flat.dtype))
    config = prod_kernel_dim.cache.get(key, None)
    # For very large problems (>= 64M elements) skip autotuning and pin a
    # heuristic config, wider for half-precision inputs.
    if m * n >= 64 * 1024 * 1024 and config is None:
        if dtype in (torch.float16, torch.bfloat16):
            config = triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2
            )
        else:
            config = triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=1
            )
        prod_kernel_dim.cache[key] = config

    # Fast path 2: a known config whose tiles divide (m, n) exactly lets us
    # run the unmasked dense kernel.
    if config is not None:
        block_m_cfg = config.kwargs["BLOCK_M"]
        block_n_cfg = config.kwargs["BLOCK_N"]
        if m % block_m_cfg == 0 and n % block_n_cfg == 0:
            grid_dense = (m // block_m_cfg,)
            with torch_device_fn.device(inp.device):
                prod_kernel_dim_dense[grid_dense](
                    inp,
                    out_flat,
                    m,
                    n,
                    inner,
                    stride_outer,
                    stride_reduce,
                    BLOCK_M=block_m_cfg,
                    BLOCK_N=block_n_cfg,
                    num_warps=config.num_warps or 4,
                    num_stages=config.num_stages or 1,
                )
            return _reshape_output(out_flat, shape, dim, keepdim)

    # Fast path 3: contiguous reduced axis short enough for one BLOCK_N tile.
    if stride_reduce == 1 and stride_outer == n and n <= 1024:
        block_m = 128 if n >= 256 else 64
        # BLOCK_N = next power of two >= n, clamped to [64, 512].
        block_n = min(512, max(64, 1 << (n - 1).bit_length()))
        grid_contig = (triton.cdiv(m, block_m),)
        with torch_device_fn.device(inp.device):
            prod_kernel_dim_contig[grid_contig](
                inp,
                out_flat,
                m,
                inner,
                stride_outer,
                BLOCK_M=block_m,
                BLOCK_N=block_n,
                num_warps=8 if n >= 256 else 4,
                num_stages=2,
            )
        return _reshape_output(out_flat, shape, dim, keepdim)

    # Seed the cache for tiny reductions so future calls can take the dense
    # path when the tiles happen to divide (m, n).
    if n <= 64:
        prod_kernel_dim.cache[key] = triton.Config(
            {"BLOCK_M": 8, "BLOCK_N": 64}, num_warps=2, num_stages=1
        )

    # Fallback: autotuned, fully masked generic kernel; BLOCK_M comes from
    # the selected config via the grid callable.
    grid = lambda meta: (triton.cdiv(m, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        prod_kernel_dim[grid](
            inp,
            out_flat,
            m,
            n,
            max(inner, 1),
            stride_outer,
            stride_reduce,
        )

    return _reshape_output(out_flat, shape, dim, keepdim)