Coverage for src/flag_gems/runtime/backend/_cambricon/ops/prod.py: 0% (97 statements)

import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op2, count_divisible_by_2

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@triton.jit
def reduce_mul(a, b):
    return a * b


@libentry()
@triton.autotune(configs=cfggen_reduce_op2(), key=["M"])
@triton.jit
def prod_kernel_mid(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
    ITER_NUM: tl.constexpr,
):
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    _tmp = tl.full([BLOCK_SIZE], value=1.0, dtype=tl.float32)
    block_start = block_start.to(tl.int64)
    # Grid-stride loop: each program accumulates an elementwise partial
    # product of its slice of the input into _tmp.
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=1.0).to(tl.float32)
        _tmp = inp_val * _tmp

    # Manual pairwise (tree) reduction, used here in place of tl.reduce:
    # each step multiplies the upper half of the active prefix into the
    # lower half, halving the prefix until one element remains.
    for x in tl.static_range(1, int(ITER_NUM), 1):
        _tmp[: BLOCK_SIZE // (2**x)] = (
            _tmp[: BLOCK_SIZE // (2**x)]
            * _tmp[BLOCK_SIZE // (2**x) : (BLOCK_SIZE // (2**x)) * 2]
        )

    mid_ptr = mid + pid
    tl.store(mid_ptr, _tmp[0])
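
# Worked example (a sketch; assumes cfggen_reduce_op2 pairs each BLOCK_SIZE
# with a matching ITER_NUM): with BLOCK_SIZE = 8 and ITER_NUM = 4, the loop
# runs x = 1, 2, 3 and shrinks the active prefix 8 -> 4 -> 2 -> 1:
#   x = 1: _tmp[:4] = _tmp[:4] * _tmp[4:8]
#   x = 2: _tmp[:2] = _tmp[:2] * _tmp[2:4]
#   x = 3: _tmp[:1] = _tmp[:1] * _tmp[1:2]
# leaving the block's full product in _tmp[0].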


@libentry()
@triton.jit
def prod_kernel_result(mid, out, mid_size: tl.constexpr, loop_num: tl.constexpr):
    offset = tl.arange(0, mid_size)
    mid_val = tl.load(mid + offset)

    # Same pairwise tree reduction as in prod_kernel_mid, applied to the
    # per-core partial products; when mid_size is not a power of two, the
    # leftover prefix is folded in by tl.reduce below.
    for x in tl.static_range(1, loop_num, 1):
        mid_val[: mid_size // (2**x)] = (
            mid_val[: mid_size // (2**x)]
            * mid_val[mid_size // (2**x) : (mid_size // (2**x)) * 2]
        )

    prod_val = tl.reduce(
        mid_val[: mid_size // (2 ** (loop_num - 1))], axis=0, combine_fn=reduce_mul
    )
    tl.store(out, prod_val)
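
# Worked example (a sketch): with mid_size = 12, count_divisible_by_2(12) = 2
# gives loop_num = 3, so the loop shrinks the active prefix 12 -> 6 -> 3 and
# tl.reduce multiplies the remaining 3 partials into the final product.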


def prod(inp, *, dtype=None):
    logger.debug("GEMS_CAMBRICON PROD")
    if dtype is None:
        dtype = inp.dtype

    M = inp.numel()
    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    mid_size = TOTAL_CORE_NUM
    # count_divisible_by_2(n) is the number of times n divides evenly by 2,
    # so loop_num is the number of halving steps available to the result kernel.
    loop_num = count_divisible_by_2(mid_size) + 1

    # mid starts as ones so that cores which receive no work still contribute
    # the multiplicative identity.
    mid = torch.ones((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        prod_kernel_mid[grid](inp, mid, M)
        prod_kernel_result[(1, 1, 1)](mid, out, mid_size, loop_num)
    return out
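
# Example usage (a sketch; the device string is hypothetical and assumes a
# torch build where the Cambricon device is addressable, e.g. "mlu"):
#   x = torch.randn(4096, device="mlu")
#   torch.testing.assert_close(prod(x), torch.prod(x))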


def heur_block_n(args):
    # Heuristic: cover the whole reduced dim in one block,
    # e.g. N = 1000 -> BLOCK_N = 1024.
    return triton.next_power_of_2(args["N"])


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("prod"),
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def prod_kernel(
    inp,
    out,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # The input is viewed as (M, N, K) with the reduction along N; each
    # program handles a BLOCK_M slab of rows for one fixed k.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    acc = tl.full((BLOCK_M, BLOCK_N), value=1.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k

        # Mask out-of-range rows and columns; 1.0 is the multiplicative identity.
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)
        acc *= inp_vals
    result = tl.reduce(acc, axis=1, combine_fn=reduce_mul)

    offset_index = m_offset * K + pid_k
    out_ptrs = out + offset_index
    mask1 = m_offset < M
    tl.store(out_ptrs, result, mask=mask1)
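
# Offset sanity check (comment-only sketch): for a contiguous (M, N, K) view,
# element (m, n, k) sits at flat index m*N*K + n*K + k, which is exactly the
# `offset` above; e.g. M=2, N=3, K=4 puts (1, 2, 3) at 1*12 + 2*4 + 3 = 23.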


def prod_dim(inp, dim=None, keepdim=False, *, dtype=None):
    logger.debug("GEMS_CAMBRICON PROD DIM")
    assert dim is not None and -inp.ndim <= dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    # Collapse the tensor to an (M, N, K) view: M is the product of the dims
    # before `dim`, N is the reduced dim, K the product of the dims after it.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1

    if dtype is None:
        dtype = inp.dtype
    out = torch.empty(shape_list, dtype=dtype, device=inp.device)
    if not keepdim:
        out = torch.squeeze(out, dim)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        prod_kernel[grid](inp, out, M, N, K)

    return out
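

# A minimal eager reference for prod_dim (a sketch for spot-checking only;
# `_prod_dim_reference` is a hypothetical helper, not part of the module API):
def _prod_dim_reference(inp, dim):
    # Reproduce the kernel's (M, N, K) view with eager torch ops and reduce
    # over N, returning the keepdim=False result.
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N
    out = inp.contiguous().view(M, N, K).prod(dim=1)
    return out.view(shape[:dim] + shape[dim + 1 :])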