Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/sum.py: 0%
148 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
12from ..utils.block_size_utils import get_block_size_1d
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the two-stage full reduction: each program sums one
    BLOCK_SIZE-wide chunk of the flattened input and stores the partial
    sum at ``mid[pid]``.

    Args:
        inp: pointer to the flattened input tensor (M elements).
        mid: pointer to the partial-sum buffer, one slot per program.
        M: total number of input elements.
        BLOCK_SIZE: compile-time chunk size handled by one program.
    """
    # Accumulate half-precision inputs in float32 to limit rounding error;
    # all other dtypes accumulate in their own element type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    # Out-of-range lanes load 0, the identity element for addition.
    mask = offset < M

    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)
@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the two-stage full reduction: a single program folds the
    ``mid_size`` partial sums produced by ``sum_kernel_1`` into one scalar
    stored at ``out``.

    BLOCK_MID must be a power of two >= mid_size so the whole buffer is
    covered by one ``tl.arange``.
    """
    # Same accumulation-dtype promotion as sum_kernel_1: half precision
    # is reduced in float32.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    # Lanes past mid_size contribute 0 to the sum.
    mask = offset < mid_size
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(mid_val)
    tl.store(out, sum_val)
def heur_m_block_size(args):
    """Heuristic for BLOCK_M: split the M rows across the 12 hardware
    clusters and round the per-cluster share up to a power of two."""
    rows_per_cluster = triton.cdiv(args["M"], 12)  # 12 = cluster_num
    return triton.next_power_of_2(rows_per_cluster)
def heur_n_block_size(args):
    """Heuristic for BLOCK_N: N rounded up to a power of two, capped at 8192."""
    import builtins

    pow2_n = triton.next_power_of_2(args["N"])
    # builtins.min is kept from the original; it behaves identically to the
    # plain builtin here.
    return builtins.min(pow2_n, 8192)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def sum_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise reduction of a contiguous (M, N) view: out[m] = sum(inp[m, :]).

    Each program handles BLOCK_M rows, sweeping the N axis in BLOCK_N-wide
    tiles and accumulating into a (BLOCK_M, BLOCK_N) register tile that is
    collapsed at the end.
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the rows of inp it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N  # per-row base pointers, shape (BLOCK_M, 1)
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Combined row/column mask; masked-off elements load 0 (additive
        # identity) and so do not affect the accumulator.
        mask = row_mask and col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    # Collapse the tile accumulator along the column axis: one sum per row.
    sum = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, sum, row_mask)
def sum(inp, *, dtype=None):
    """Reduce every element of ``inp`` to a single 0-d tensor.

    Two-stage reduction: ``sum_kernel_1`` writes one partial sum per block
    into a scratch buffer, then ``sum_kernel_2`` folds the partials into
    the final scalar. Bool inputs are promoted to int64, matching
    ``torch.sum`` semantics.
    """
    logger.debug("GEMS SUM")
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    block_size = get_block_size_1d(numel, inp.element_size())
    num_blocks = triton.cdiv(numel, block_size)
    block_mid = triton.next_power_of_2(num_blocks)

    partials = torch.empty((num_blocks,), dtype=dtype, device=inp.device)
    result = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(num_blocks, 1, 1)](
            inp, partials, numel, block_size, buffer_size_limit=2048
        )
        # A single block already holds the full sum; skip stage 2.
        if num_blocks == 1:
            return partials.reshape([])
        sum_kernel_2[(1, 1, 1)](
            partials, result, num_blocks, block_mid, buffer_size_limit=2048
        )
        return result
def sum_out(inp, *, dtype=None, out):
    """Like :func:`sum`, but writes the scalar result into ``out``.

    Args:
        inp: input tensor (reduced over all elements).
        dtype: accumulation/result dtype; defaults to ``inp.dtype``
            (bool promotes to int64, matching ``torch.sum``).
        out: preallocated 0-d result tensor; always returned.

    Bugfix: previously, when the whole input fit in a single block
    (``mid_size == 1``) the internal partial-sum buffer was returned and
    ``out`` was never written, so callers holding ``out`` saw stale data.
    The single partial is now copied into ``out``.
    """
    logger.debug("GEMS SUM_OUT")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            # Copy the lone partial into the caller's tensor instead of
            # returning the internal buffer and leaving `out` untouched.
            out.copy_(mid.reshape(out.shape))
            return out
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
        return out
def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over the dimensions in ``dim``.

    Args:
        inp: input tensor.
        dim: list of dimensions to reduce; ``None`` or ``[]`` means reduce
            over all dimensions (as in ``torch.sum``). Negative dims are
            normalized.
        keepdim: keep reduced dimensions with size 1.
        dtype: result dtype; defaults to ``inp.dtype`` (bool -> int64).

    Bugfix: the declared default ``dim=None`` previously crashed in the
    ``[d % inp.ndim for d in dim]`` comprehension; ``None`` is now treated
    as a full reduction, matching ``torch.sum`` semantics.
    """
    logger.debug("GEMS SUM DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if dim is None or dim == []:
        if not keepdim:
            return sum(inp, dtype=dtype)
        else:
            # Keep rank: a full reduction with keepdim yields all-ones shape.
            dim_num = inp.ndim
            return torch.reshape(sum(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # Move reduced dims to the innermost positions so the kernel sees a
    # contiguous (M, N) layout with N the reduction extent.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    out = torch.empty(shape, dtype=dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Sum ``inp`` over ``dim``, writing the result into ``out``.

    Args:
        inp: input tensor.
        dim: list of dimensions to reduce; ``None`` or ``[]`` means reduce
            over all dimensions (as in ``torch.sum``). Negative dims are
            normalized.
        keepdim: keep reduced dimensions with size 1 (``out`` is squeezed
            in place otherwise).
        dtype: result dtype; defaults to ``inp.dtype`` (bool -> int64).
        out: preallocated output tensor of the keepdim shape.

    Bugfix: the declared default ``dim=None`` previously crashed in the
    ``[d % inp.ndim for d in dim]`` comprehension; ``None`` is now treated
    as a full reduction, matching ``torch.sum`` semantics.
    """
    logger.debug("GEMS SUM_DIM_OUT")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if dim is None or dim == []:
        if not keepdim:
            return sum_out(inp, dtype=dtype, out=out)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum_out(inp, dtype=dtype, out=out), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # Move reduced dims innermost so the kernel sees a contiguous (M, N) view.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out.squeeze_(dim=dim)
    return out