Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/prod.py: 0%
90 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
12from ..utils.block_size_utils import get_block_size_1d
# Module logger, a child of the package-wide "flag_gems" logger.
# NOTE(review): lstrip(".") strips *leading* dots only — a no-op for a normal
# dotted module path; presumably intended to guard against relative names.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def reduce_mul(a, b):
    """Combine function for tl.reduce: the product of two partial results."""
    combined = a * b
    return combined
@libentry()
@triton.jit
def prod_kernel_mid(
    inp,  # input tensor, indexed flat (NOTE(review): assumes contiguous — confirm at call sites)
    mid,  # output: one partial product per program
    M,  # total number of elements in `inp`
    BLOCK_SIZE: tl.constexpr,  # elements reduced by each program
):
    """First stage of the all-elements product.

    Each program loads BLOCK_SIZE contiguous elements, multiplies them
    together and stores a single partial product into mid[pid].
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # other=1.0 is the multiplicative identity, so masked-off (out-of-range)
    # lanes do not change the product; accumulate in float32.
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
    mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul)
    mid_ptr = mid + pid
    # inp_val was cast to float32 above, so this .to() keeps float32; the
    # store converts to `mid`'s element type.
    tl.store(mid_ptr, mid_value.to(inp_val.dtype))
@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage: fold all `mid_size` partial products into one scalar.

    Runs as a single program; BLOCK_MID must be >= mid_size so one
    tl.arange covers every partial product.
    """
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # Identity element 1.0 for lanes beyond mid_size; reduce in float32.
    mid_val = tl.load(mid_ptrs, mask=mask, other=1.0).to(tl.float32)
    prod_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_mul)
    tl.store(out, prod_val)
def prod(inp, *, dtype=None):
    """Product of every element of `inp`, via a two-stage block reduction.

    Stage one produces one partial product per block; stage two multiplies
    the partials into a scalar. Returns a 0-dim tensor of dtype `dtype`
    (defaults to `inp.dtype`).
    """
    logger.debug("GEMS PROD")
    out_dtype = inp.dtype if dtype is None else dtype

    numel = inp.numel()
    blk = get_block_size_1d(numel, inp.element_size())
    n_blocks = triton.cdiv(numel, blk)
    final_blk = triton.next_power_of_2(n_blocks)

    partials = torch.empty((n_blocks,), dtype=out_dtype, device=inp.device)
    result = torch.empty([], dtype=out_dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[(n_blocks, 1, 1)](
            inp, partials, numel, blk, buffer_size_limit=2048
        )
        # A single block already holds the full product — skip stage two.
        if n_blocks == 1:
            return partials.reshape([])
        prod_kernel_result[(1, 1, 1)](
            partials, result, n_blocks, final_blk, buffer_size_limit=2048
        )
    return result
def heur_m_block_size(args):
    """Heuristic BLOCK_M: split M rows across the device's 12 clusters."""
    rows_per_cluster = triton.cdiv(args["M"], 12)  # 12 == cluster_num
    # Round up to a power of two as tl.arange requires.
    return triton.next_power_of_2(rows_per_cluster)
def heur_n_block_size(args):
    """Heuristic BLOCK_N: power-of-two cover of N, capped at 8192."""
    import builtins

    n_pow2 = triton.next_power_of_2(args["N"])
    # builtins.min guards against any shadowing of `min` in this namespace.
    return builtins.min(n_pow2, 8192)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def prod_kernel(
    inp,  # input viewed as an (M, N) row-major matrix, reduced along N
    out,  # output: M row products
    M,  # number of rows (all kept dims flattened)
    N,  # reduction length (size of the reduced dim)
    BLOCK_M: tl.constexpr,  # rows handled per program (from heuristic)
    BLOCK_N: tl.constexpr,  # columns loaded per loop iteration (from heuristic)
):
    """Row-wise product: out[m] = product of inp[m, 0..N-1]."""
    # Rows covered by this program.
    pid_m = tle.program_id(0)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Running per-lane product, initialized to the multiplicative identity.
    acc = tl.full((BLOCK_M, BLOCK_N), value=1.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]

        # Mask out rows past M and columns past N; other=1.0 below keeps
        # masked lanes neutral for the product.
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
        acc *= inp_vals
    # Fold the BLOCK_N lanes of each row into a single product.
    result_index = tl.reduce(acc, axis=1, combine_fn=reduce_mul)

    offset_index = m_offset
    out_ptrs = out + offset_index
    mask1 = m_offset < M
    tl.store(out_ptrs, result_index, mask=mask1)
def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Product of the elements of `inp` along dimension `dim`.

    Args:
        inp: input tensor.
        dim: dimension to reduce; may be negative. If None, the product of
            all elements is returned (previously `dim=None` raised a
            TypeError at the range check despite being the default).
        keepdim: keep the reduced dimension with size 1.
        dtype: output dtype; defaults to `inp.dtype`.

    Returns:
        Tensor of products along `dim` (0-dim tensor when `dim` is None
        and keepdim is False).

    Raises:
        IndexError: if `dim` is out of range for `inp`.
    """
    logger.debug("GEMS PROD DIM")
    if dtype is None:
        dtype = inp.dtype

    if dim is None:
        # Full reduction over every element, like torch.prod(input).
        out = prod(inp, dtype=dtype)
        if keepdim:
            out = out.reshape([1] * inp.ndim)
        return out

    # Explicit range check instead of `assert` so validation survives -O.
    if not (-inp.ndim <= dim < inp.ndim):
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-inp.ndim}, {inp.ndim - 1}], but got {dim})"
        )

    shape = list(inp.shape)
    dim = dim % inp.ndim
    # Move the reduced dimension to the innermost, contiguous position.
    inp = dim_compress(inp, dim)
    N = shape[dim]
    shape[dim] = 1

    if N == 0:
        # Empty product along `dim` is the multiplicative identity, 1.
        # Also avoids the ZeroDivisionError in the M computation below.
        out = torch.ones(shape, dtype=dtype, device=inp.device)
        return out if keepdim else torch.squeeze(out, dim)

    M = inp.numel() // N
    out = torch.empty(shape, dtype=dtype, device=inp.device)
    if not keepdim:
        out = torch.squeeze(out, dim)

    # One program per BLOCK_M rows; BLOCK_M/BLOCK_N come from the heuristics
    # attached to prod_kernel.
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        prod_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    return out