Coverage for src/flag_gems/ops/prod.py: 64% — 86 statements (coverage.py v7.6.9, created at 2026-03-19 02:32 +0800)
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
# Module-level logger, keyed by module name per the package-wide convention.
logger = logging.getLogger(__name__)
@triton.jit
def reduce_mul(a, b):
    # Binary combine function for tl.reduce: elementwise product of two lanes.
    return a * b
@libentry()
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full-tensor product: each program multiplies one
    BLOCK_SIZE chunk of the flattened input (M total elements) and writes
    its partial product to mid[pid]."""
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Out-of-range lanes load 1.0, the multiplicative identity, so they do
    # not affect the product; accumulation is done in float32.
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
    mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul)
    mid_ptr = mid + pid
    # NOTE(review): inp_val was upcast to float32 above, so this .to() keeps
    # float32; the store itself converts to mid's element type.
    tl.store(mid_ptr, mid_value.to(inp_val.dtype))
@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage: a single program multiplies all mid_size partial
    products in `mid` into the scalar `out`.  BLOCK_MID must be a
    power of two >= mid_size (the caller passes next_power_of_2(mid_size))."""
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # Masked lanes contribute the multiplicative identity 1.0.
    mid_val = tl.load(mid_ptrs, mask=mask, other=1.0).to(tl.float32)
    prod_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_mul)
    tl.store(out, prod_val)
def prod(inp, *, dtype=None):
    """Product of all elements of `inp` (full reduction).

    Two-stage reduction: a first kernel folds ~sqrt(M)-sized chunks of the
    flattened input into a `mid` buffer, then a second kernel folds `mid`
    into a 0-dim result, so both stages do O(sqrt(M)) sequential work.

    Args:
        inp: input tensor; all elements are reduced.
        dtype: optional result dtype; defaults to ``inp.dtype``.

    Returns:
        A 0-dim tensor on ``inp``'s device holding the product.
    """
    logger.debug("GEMS PROD")
    if dtype is None:
        dtype = inp.dtype

    M = inp.numel()
    # Empty reduction: the product over no elements is 1 (matches torch.prod).
    # Without this guard, triton.next_power_of_2(0) yields 0, making
    # block_size/block_mid invalid (tl.arange needs a power of two >= 1).
    if M == 0:
        return torch.ones([], dtype=dtype, device=inp.device)

    # The kernels index the input with flat offsets, so a dense layout is
    # required; contiguous() is a no-op for already-contiguous tensors.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[(mid_size, 1, 1)](inp, mid, M, block_size)
        prod_kernel_result[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
def heur_block_n(args):
    """Heuristic for BLOCK_N: round the reduction length N up to the next
    power of two so a single tile covers the whole row."""
    n = args["N"]
    return triton.next_power_of_2(n)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def prod_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise product over an (M, N) row-major view of `inp`: each program
    reduces BLOCK_M rows along their N inner elements and writes one value
    per row to `out`."""
    # Rows handled by this program.
    pid_m = tle.program_id(0)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Accumulate partial products in float32; 1.0 is the multiplicative identity.
    acc = tl.full((BLOCK_M, BLOCK_N), value=1.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]

        # Mask out-of-range rows/columns; masked lanes load the identity 1.0.
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
        acc *= inp_vals
    # Fold each row's BLOCK_N partial products into a single value.
    result_index = tl.reduce(acc, axis=1, combine_fn=reduce_mul)

    offset_index = m_offset
    out_ptrs = out + offset_index
    mask1 = m_offset < M
    tl.store(out_ptrs, result_index, mask=mask1)
def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Product of `inp` along a single dimension.

    Args:
        inp: input tensor.
        dim: dimension to reduce; may be negative.  NOTE(review): despite the
            ``None`` default, ``dim=None`` is not supported — it fails the
            range check below; callers must pass an int.
        keepdim: when True, keep the reduced dimension with size 1.
        dtype: optional result dtype; defaults to ``inp.dtype``.

    Returns:
        Tensor with `dim` reduced (retained as size 1 if `keepdim`).
    """
    logger.debug("GEMS PROD DIM")

    ndim = inp.ndim
    assert -ndim <= dim < ndim, "Invalid dim"
    dim = dim % ndim

    out_shape = list(inp.shape)
    N = out_shape[dim]
    out_shape[dim] = 1

    # Move the reduced dimension innermost so every output element maps to a
    # contiguous run of N input values for the row-wise kernel.
    inp = dim_compress(inp, dim)
    M = inp.numel() // N

    out_dtype = inp.dtype if dtype is None else dtype
    out = torch.empty(out_shape, dtype=out_dtype, device=inp.device)
    if not keepdim:
        out = torch.squeeze(out, dim)

    with torch_device_fn.device(inp.device):
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        prod_kernel[grid](inp, out, M, N)

    return out