Coverage for src/flag_gems/runtime/backend/_metax/ops/prod.py: 0%
90 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
# Module-level logger, namespaced under the flag_gems logger hierarchy.
logger = logging.getLogger("flag_gems." + __name__)
@triton.jit
def reduce_mul(a, b):
    """Combine function for `tl.reduce`: return the product of the operands."""
    product = a * b
    return product
@libentry()
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First reduction stage of the full-tensor product.

    Each program multiplies one BLOCK_SIZE-wide chunk of the flattened
    input (accumulated in float32) and writes its partial product to
    mid[program_id]. Out-of-range lanes load 1.0, the multiplicative
    identity, so masked elements do not affect the result.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    vals = tl.load(inp + idx, mask=in_bounds, other=1.0).to(tl.float32)
    partial = tl.reduce(vals, axis=0, combine_fn=reduce_mul)
    tl.store(mid + block_id, partial.to(vals.dtype))
@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second reduction stage: fold all partial products in `mid` into
    the scalar `out`. Launched as a single program; BLOCK_MID must be a
    power of two covering mid_size, masked lanes contribute 1.0.
    """
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size
    partials = tl.load(mid + idx, mask=valid, other=1.0).to(tl.float32)
    total = tl.reduce(partials, axis=0, combine_fn=reduce_mul)
    tl.store(out, total)
def prod(inp, *, dtype=None):
    """Product of all elements of `inp`, computed in two kernel stages.

    Stage one splits the flattened input into roughly sqrt(numel) blocks
    and emits one partial product per block; stage two multiplies the
    partials into a 0-dim tensor. `dtype` defaults to the input dtype.
    """
    logger.debug("METAX GEMS PROD")
    if dtype is None:
        dtype = inp.dtype

    numel = inp.numel()
    # Balance the two stages: ~sqrt(numel) blocks of ~sqrt(numel) elements each.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    mid_size = triton.cdiv(numel, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[(mid_size, 1, 1)](inp, mid, numel, block_size)
        prod_kernel_result[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
def heur_block_n(args):
    """Heuristic: span the whole reduction axis N with one power-of-two block."""
    n = args["N"]
    return triton.next_power_of_2(n)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("prod"),
    key=[
        "M",
        "N",
    ],
)
@triton.heuristics(
    {
        # BLOCK_N = next_power_of_2(N) (see heur_block_n), so the inner
        # loop below executes exactly one iteration in practice.
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def prod_kernel(
    inp,
    out,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Product reduction over the middle axis of a contiguous (M, N, K) view.

    Grid: axis 0 tiles M in BLOCK_M rows, axis 1 enumerates K. Each
    program reduces N for its (row-tile, k) slice and writes BLOCK_M
    results to the (M, K) output. Accumulation is in float32; masked
    loads use 1.0, the multiplicative identity.
    """
    # Row tile on axis M and fixed K index for this program.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    acc = tl.full((BLOCK_M, BLOCK_N), value=1.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset into the contiguous (M, N, K) layout: m*N*K + n*K + k.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k

        # Mask rows past M and columns past N; masked lanes load 1.0.
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
        acc *= inp_vals
    # Fold the BLOCK_N column of partial products into one value per row.
    result_index = tl.reduce(acc, axis=1, combine_fn=reduce_mul)

    # Store the per-row products at their (m, k) positions in the (M, K) output.
    offset_index = m_offset * K + pid_k
    out_ptrs = out + offset_index
    mask1 = m_offset < M
    tl.store(out_ptrs, result_index, mask=mask1)
def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Product of `inp` along a single dimension `dim`.

    The input is treated as a contiguous (M, N, K) view where N is the
    reduced axis, M the product of the leading dims and K the product of
    the trailing dims; `prod_kernel` reduces N. `dtype` defaults to the
    input dtype; `keepdim` keeps the reduced axis as size 1.
    """
    logger.debug("METAX GEMS PROD DIM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"

    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]  # length of the reduced axis
    M = math.prod(shape[:dim])  # rows: product of dims before `dim`
    K = inp.numel() // M // N  # cols: product of dims after `dim`

    # The kernel assumes a dense (M, N, K) layout.
    inp = inp.contiguous()

    out_shape = list(shape)
    out_shape[dim] = 1
    if dtype is None:
        dtype = inp.dtype
    out = torch.empty(out_shape, dtype=dtype, device=inp.device)
    if not keepdim:
        out = torch.squeeze(out, dim)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        prod_kernel[grid](inp, out, M, N, K)
    return out