Coverage for src/flag_gems/runtime/backend/_mthreads/ops/prod.py: 0%
188 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
2import math
3from typing import Sequence
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
# Module-level logger named for this file inside the mthreads backend namespace.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rsplit(".", 1)[-1]
)
@triton.jit
def reduce_mul(a, b):
    # Binary combine function for tl.reduce: fold two partial products.
    product = a * b
    return product
# Autotuning candidates for the generic reduction kernel (prod_kernel_dim).
# Tiles range from small (8x64, for tiny shapes) to large (64x1024, for long
# reduction axes); _prune_reduction_configs trims this list per problem shape.
NAIVE_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 8, "BLOCK_N": 64}, num_warps=2),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=8, num_stages=2),
]
37def _prune_reduction_configs(configs, named_args, **meta):
38 """Skip oversized tiles to avoid needless autotune on tiny shapes."""
39 M = named_args["M"]
40 N = named_args["N"]
41 max_block_m = max(M, 8)
42 min_block_m = 8
43 n_cap = 1 << (N - 1).bit_length()
44 n_cap = max(64, min(n_cap, 1024))
45 filtered = [
46 cfg
47 for cfg in configs
48 if min_block_m <= cfg.kwargs["BLOCK_M"] <= max_block_m
49 and cfg.kwargs["BLOCK_N"] <= max(256, n_cap)
50 ]
51 return filtered or configs
54def _flatten_dim(shape: Sequence[int], dim: int):
55 dim = dim % len(shape)
56 n = shape[dim]
57 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1
58 outer = math.prod(shape[:dim]) if dim > 0 else 1
59 return dim, n, inner, outer
62def _reshape_output(out: torch.Tensor, shape: list[int], dim: int, keepdim: bool):
63 out_shape = shape.copy()
64 out_shape[dim] = 1
65 out_view = out.view(out_shape)
66 if not keepdim:
67 out_view = torch.squeeze(out_view, dim)
68 return out_view
@libentry()
@triton.jit
def prod_kernel_mid(
    inp,  # pointer to the flattened input tensor
    mid,  # per-program partial products, one element per program
    M,  # total number of input elements
    BLOCK_SIZE: tl.constexpr,  # elements reduced by each program
):
    # Stage 1 of the full-tensor product: each program multiplies BLOCK_SIZE
    # contiguous elements together (accumulating in fp32) and writes a single
    # partial product to mid[pid].
    dtype = inp.type.element_ty
    acc_dtype = tl.float32
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # other=1.0 is the multiplicative identity, so out-of-range lanes are no-ops.
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(acc_dtype)
    mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul).to(dtype)
    mid_ptr = mid + pid
    tl.store(mid_ptr, mid_value)
@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full-tensor product: a single program multiplies all
    # `mid_size` partial products (padded to BLOCK_MID lanes) into scalar `out`.
    dtype = mid.type.element_ty
    acc_dtype = tl.float32
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # other=1.0: multiplicative identity for lanes past mid_size.
    mid_val = tl.load(mid_ptrs, mask=mask, other=1.0).to(acc_dtype)
    prod_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_mul).to(dtype)
    tl.store(out, prod_val)
@triton.jit
def prod_kernel_dim_64(
    inp,  # contiguous input viewed as (M, 64)
    out,  # flat output of M row products
    M,  # number of rows (outer * inner)
    INNER,  # unused here; kept for launcher signature parity
    STRIDE_OUTER,  # stride between consecutive rows (caller passes 64)
    BLOCK_M: tl.constexpr,  # rows handled per program
):
    # Specialized reduction for a contiguous last dim of exactly 64: each
    # program loads a (BLOCK_M, 64) tile and reduces it along axis 1 in fp32.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = rows < M
    base_ptr = inp + rows * STRIDE_OUTER
    cols = tl.arange(0, 64)
    # BUGFIX: the load must honor row_mask. The launcher's grid is
    # cdiv(m, BLOCK_M) with no divisibility check, so the last program read
    # past the end of the tensor whenever M % BLOCK_M != 0 (the masked store
    # below shows partial tiles were anticipated). other=1.0 keeps masked
    # lanes neutral for the product.
    vals = tl.load(
        base_ptr[:, None] + cols[None, :],
        mask=row_mask[:, None],
        other=1.0,
        cache_modifier=".cg",
    )
    prod_vals = tl.reduce(vals.to(tl.float32), axis=1, combine_fn=reduce_mul)
    tl.store(out + rows, prod_vals.to(inp.type.element_ty), mask=row_mask)
@triton.jit
def prod_kernel_dim_contig(
    inp,  # input with a contiguous (stride-1) reduction axis
    out,  # flat output of M row products
    M,  # number of rows (outer * inner)
    INNER,  # unused in this kernel; kept for launcher signature parity
    STRIDE_OUTER,  # stride between rows; launcher guarantees it equals N
    BLOCK_M: tl.constexpr,  # rows per program
    BLOCK_N: tl.constexpr,  # lanes along the reduction axis (>= N)
):
    # Single-pass reduction for contiguous rows of length N <= BLOCK_N:
    # one (BLOCK_M, BLOCK_N) tile per program, reduced along axis 1 in fp32.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = rows < M
    base_ptr = inp + rows * STRIDE_OUTER
    cols = tl.arange(0, BLOCK_N)
    # NOTE(review): STRIDE_OUTER doubles as the row length here — valid only
    # because the launcher selects this kernel when stride_outer == n.
    col_mask = cols[None, :] < STRIDE_OUTER
    mask = row_mask[:, None] & col_mask
    # other=1.0: multiplicative identity for masked lanes.
    vals = tl.load(
        base_ptr[:, None] + cols[None, :],
        mask=mask,
        other=1.0,
        cache_modifier=".cg",
    )
    prod_vals = tl.reduce(vals.to(tl.float32), axis=1, combine_fn=reduce_mul)
    tl.store(out + rows, prod_vals.to(inp.type.element_ty), mask=row_mask)
@triton.jit
def prod_kernel_dim_dense(
    inp,  # input tensor pointer
    out,  # flat output of M row products
    M,  # number of rows (outer * inner)
    N,  # extent of the reduction axis
    INNER,  # inner extent; rows are decomposed as (outer, inner)
    STRIDE_OUTER,  # stride between outer indices
    STRIDE_REDUCE,  # stride along the reduction axis
    BLOCK_M: tl.constexpr,  # rows per program
    BLOCK_N: tl.constexpr,  # reduction-axis chunk per loop iteration
):
    # Mask-free variant: the launcher only dispatches here when
    # M % BLOCK_M == 0 and N % BLOCK_N == 0, so every tile is full and the
    # bounds checks (and their cost) can be dropped entirely.
    dtype = inp.type.element_ty
    acc_dtype = tl.float32
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx
    # Running product per row, accumulated in fp32 for precision.
    acc = tl.full((BLOCK_M,), value=1.0, dtype=acc_dtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        vals = tl.load(
            base_ptr[:, None] + cols[None, :] * STRIDE_REDUCE,
            cache_modifier=".cg",
        ).to(acc_dtype)
        chunk_prod = tl.reduce(vals, axis=1, combine_fn=reduce_mul)
        acc *= chunk_prod
    tl.store(out + rows, acc.to(dtype))
# Generic autotuned fallback: handles any (M, N) shape, any strides, with
# full bounds masking. Candidate tiles are pruned per shape before tuning.
@triton.autotune(
    configs=NAIVE_REDUCTION_CONFIGS,
    key=["M", "N"],
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
    warmup=2,
    rep=8,
)
@triton.jit
def prod_kernel_dim(
    inp,  # input tensor pointer
    out,  # flat output of M row products
    M,  # number of rows (outer * inner)
    N,  # extent of the reduction axis
    INNER,  # inner extent; rows decompose as (outer, inner)
    STRIDE_OUTER,  # stride between outer indices
    STRIDE_REDUCE,  # stride along the reduction axis
    BLOCK_M: tl.constexpr,  # rows per program
    BLOCK_N: tl.constexpr,  # reduction-axis chunk per loop iteration
):
    dtype = inp.type.element_ty
    acc_dtype = tl.float32
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    # int64 indices so address arithmetic cannot overflow on large tensors.
    rows = rows.to(tl.int64)
    row_mask = rows < M
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx
    # Running product per row, accumulated in fp32 for precision.
    acc = tl.full((BLOCK_M,), value=1.0, dtype=acc_dtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        cols = cols.to(tl.int64)
        col_mask = cols < N
        mask = row_mask[:, None] & col_mask[None, :]
        # other=1.0: multiplicative identity for out-of-bounds lanes.
        vals = tl.load(
            base_ptr[:, None] + cols[None, :] * STRIDE_REDUCE,
            mask=mask,
            other=1.0,
            cache_modifier=".cg",
        ).to(acc_dtype)
        chunk_prod = tl.reduce(vals, axis=1, combine_fn=reduce_mul)
        acc *= chunk_prod
    out_ptrs = out + rows
    tl.store(out_ptrs, acc.to(dtype), mask=row_mask)
def prod(inp, *, dtype=None):
    """Product of all elements of *inp*, returned as a 0-d tensor.

    Two-stage reduction: stage one writes one partial product per block,
    stage two multiplies the partials into a single scalar.
    """
    logger.debug("GEMS_MTHREADS PROD")
    out_dtype = inp.dtype if dtype is None else dtype
    if not inp.is_contiguous():
        inp = inp.contiguous()
    numel = inp.numel()
    # Block size ~ 2 * sqrt(numel) rounded up to a power of two, capped at
    # 4096 and at the whole tensor, to balance the two reduction stages.
    sqrt_block = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    elems_per_block = min(sqrt_block * 2, 4096, triton.next_power_of_2(numel))
    num_blocks = triton.cdiv(numel, elems_per_block)
    partials_pow2 = triton.next_power_of_2(num_blocks)

    partials = torch.empty((num_blocks,), dtype=out_dtype, device=inp.device)
    result = torch.empty([], dtype=out_dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[(num_blocks, 1, 1)](inp, partials, numel, elems_per_block)
        prod_kernel_result[(1, 1, 1)](partials, result, num_blocks, partials_pow2)
    return result
def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Product of *inp* along a single dimension *dim*.

    Dispatches, in order of specialization: a mask-free kernel for n == 64
    contiguous rows; a dense maskless kernel when a cached/heuristic config
    divides the shape exactly; a single-pass contiguous kernel for small
    contiguous rows; and finally the autotuned generic kernel.
    """
    logger.debug("GEMS_MTHREADS PROD DIM")
    assert dim is not None, "dim must be specified"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim

    if dtype is None:
        dtype = inp.dtype
    # Non-contiguous input: move the reduced dim to the end and compress the
    # rest, so the reduction axis becomes the (contiguous) last dimension.
    if not inp.is_contiguous():
        inp = dim_compress(inp, dim)
        dim = inp.ndim - 1

    shape = list(inp.shape)
    # n: reduced extent; inner/outer: products of extents after/before dim.
    dim, n, inner, outer = _flatten_dim(shape, dim)
    m = outer * inner

    out_flat = torch.empty((m,), dtype=dtype, device=inp.device)

    stride = inp.stride()
    stride_reduce = stride[dim]
    # Stride between consecutive output rows along the reduction layout.
    stride_outer = stride_reduce * n

    # Fast path: contiguous rows of exactly 64 elements.
    if n == 64 and stride_reduce == 1 and stride_outer == n:
        grid_64 = (triton.cdiv(m, 8),)
        with torch_device_fn.device(inp.device):
            prod_kernel_dim_64[grid_64](
                inp, out_flat, m, inner, stride_outer, BLOCK_M=8, num_warps=2
            )
        return _reshape_output(out_flat, shape, dim, keepdim)

    # NOTE(review): this pokes a custom tuple key into the autotuner's cache
    # dict; the autotuner's own lookups use its internal key format, so these
    # entries appear to be read back only by this launcher — confirm the two
    # key spaces never collide.
    key = (m, n, str(dtype), str(out_flat.dtype))
    config = prod_kernel_dim.cache.get(key, None)
    # Very large problems (>= 64M elements): skip autotuning and seed a
    # dtype-dependent heuristic config.
    if m * n >= 64 * 1024 * 1024 and config is None:
        if dtype in (torch.float16, torch.bfloat16):
            config = triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=2
            )
        else:
            config = triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=1
            )
        prod_kernel_dim.cache[key] = config

    # With a known-good config whose tile divides the shape exactly, use the
    # maskless dense kernel (no bounds checks needed).
    if config is not None:
        block_m_cfg = config.kwargs["BLOCK_M"]
        block_n_cfg = config.kwargs["BLOCK_N"]
        if m % block_m_cfg == 0 and n % block_n_cfg == 0:
            grid_dense = (m // block_m_cfg,)
            with torch_device_fn.device(inp.device):
                prod_kernel_dim_dense[grid_dense](
                    inp,
                    out_flat,
                    m,
                    n,
                    inner,
                    stride_outer,
                    stride_reduce,
                    BLOCK_M=block_m_cfg,
                    BLOCK_N=block_n_cfg,
                    num_warps=config.num_warps or 4,
                    num_stages=config.num_stages or 1,
                )
            return _reshape_output(out_flat, shape, dim, keepdim)

    # Contiguous rows short enough for a single-pass tile (n <= 1024).
    if stride_reduce == 1 and stride_outer == n and n <= 1024:
        block_m = 128 if n >= 256 else 64
        # BLOCK_N: next power of two >= n, clamped to [64, 512].
        block_n = min(512, max(64, 1 << (n - 1).bit_length()))
        grid_contig = (triton.cdiv(m, block_m),)
        with torch_device_fn.device(inp.device):
            prod_kernel_dim_contig[grid_contig](
                inp,
                out_flat,
                m,
                inner,
                stride_outer,
                BLOCK_M=block_m,
                BLOCK_N=block_n,
                num_warps=8 if n >= 256 else 4,
                num_stages=2,
            )
        return _reshape_output(out_flat, shape, dim, keepdim)

    # Tiny reduction axis: pre-seed a small config so future calls with this
    # key skip the large-shape heuristic above.
    if n <= 64:
        prod_kernel_dim.cache[key] = triton.Config(
            {"BLOCK_M": 8, "BLOCK_N": 64}, num_warps=2, num_stages=1
        )

    # Generic autotuned fallback; the autotuner picks BLOCK_M/BLOCK_N.
    grid = lambda meta: (triton.cdiv(m, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        prod_kernel_dim[grid](
            inp,
            out_flat,
            m,
            n,
            max(inner, 1),
            stride_outer,
            stride_reduce,
        )

    return _reshape_output(out_flat, shape, dim, keepdim)