Coverage for src/flag_gems/ops/sum.py: 45%
232 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2import math
3from functools import reduce
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.zeros import zero_
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import dim_compress, libentry, libtuner
13from flag_gems.utils import triton_lang_extension as tle
15logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the two-pass full reduction: each program sums one
    # BLOCK_SIZE slice of the flattened input and writes a single partial
    # sum into `mid`. Half-precision inputs accumulate in float32.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tle.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-range lanes contribute 0 to the sum.
    values = tl.load(inp + offsets, mask=in_bounds, other=0).to(cdtype)
    partial = tl.sum(values)
    tl.store(mid + pid, partial)
@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the two-pass full reduction: a single program folds every
    # partial sum in `mid` into one scalar stored at `out`.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    partials = tl.load(mid + lanes, mask=valid, other=0).to(cdtype)
    tl.store(out, tl.sum(partials))
def sum(inp, *, dtype=None):
    """Reduce all elements of `inp` to a 0-d tensor.

    Two-stage reduction: stage 1 writes per-block partial sums into a
    `mid` buffer, stage 2 folds `mid` into the scalar output.
    """
    logger.debug("GEMS SUM")
    inp = inp.contiguous()
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        # Summing booleans counts True values; promote to int64 like PyTorch.
        inp = inp.to(torch.int64)
        dtype = torch.int64

    # Balance the two stages: ~sqrt(numel) blocks of ~sqrt(numel) elements.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    mid_size = triton.cdiv(numel, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, numel, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
def sum_out(inp, *, dtype=None, out):
    """Reduce all elements of `inp` into the caller-provided `out` tensor.

    Same two-stage reduction as `sum`, except the final kernel stores into
    `out`, which is also returned.
    """
    logger.debug("GEMS SUM_OUT")
    # Fix: the reduction kernels index the raw flat buffer, so the input
    # must be made contiguous first (as `sum` does); otherwise a strided
    # input would be summed over the wrong elements.
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        # Summing booleans counts True values; promote to int64 like PyTorch.
        inp = inp.to(torch.int64)
        dtype = torch.int64
    # ~sqrt(M) blocks of ~sqrt(M) elements balances the two stages.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def sum_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Reduce axis N of a logically (M, N, K) contiguous input. Each program
    # handles one m-row and one TILE_K-wide stripe of the trailing K axis.
    # Half-precision inputs accumulate in float32.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]

    if ONE_TILE_PER_CTA:
        # The whole N axis fits in a single tile: load once, reduce, store.
        n_offsets = tl.arange(0, TILE_N)[:, None]
        offsets = pid_m * N * K + n_offsets * K + k_offsets
        mask = (n_offsets < N) & (k_offsets < K)
        tile = tl.load(input_ptr + offsets, mask=mask, other=0).to(cdtype)
        result = tl.sum(tile, axis=0, keep_dims=True)
        tl.store(output_ptr + pid_m * K + k_offsets, result, mask=k_offsets < K)
    else:
        # N is larger than one tile: accumulate tile by tile, reduce at the end.
        acc = tl.zeros([TILE_N, TILE_K], dtype=cdtype)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
            offsets = pid_m * N * K + n_offsets * K + k_offsets
            mask = (n_offsets < N) & (k_offsets < K)
            acc += tl.load(input_ptr + offsets, mask=mask, other=0).to(cdtype)
        result = tl.sum(acc, axis=0, keep_dims=True)
        tl.store(output_ptr + pid_m * K + k_offsets, result, mask=k_offsets < K)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def sum_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Reduce the innermost (contiguous) axis of an (M, N) input; one program
    # per row. Half-precision inputs accumulate in float32.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # The whole row fits in a single tile: load once, reduce, store.
        n_offsets = tl.arange(0, TILE_N)
        row = tl.load(
            input_ptr + pid_m * N + n_offsets, mask=n_offsets < N, other=0
        ).to(cdtype)
        tl.store(output_ptr + pid_m, tl.sum(row, axis=0))
    else:
        # Row is longer than one tile: accumulate chunk by chunk.
        acc = tl.zeros([TILE_N], dtype=cdtype)
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            chunk = tl.load(
                input_ptr + pid_m * N + n_offsets, mask=n_offsets < N, other=0
            ).to(cdtype)
            acc += chunk
        tl.store(output_ptr + pid_m, tl.sum(acc, axis=0))
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def sum_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise reduction of an (M, N) view: each program sums BLOCK_M rows,
    # sweeping the N axis in chunks of BLOCK_N. Used for multi-dim reductions
    # after the reduced dims have been compressed to the trailing axis.
    # Half-precision inputs accumulate in float32 for accuracy.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the row of inp it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    # Rebind the input/output pointers to this program's rows.
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): `and` between two tensor masks relies on Triton's
        # lowering of Python boolean ops to elementwise logic — confirm it
        # behaves like `&` on the Triton version in use.
        mask = row_mask and col_mask

        # Masked lanes load 0, so they do not affect the accumulator.
        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    # Collapse the BLOCK_N accumulator lanes into one sum per row.
    sum = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, sum, row_mask)
def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Shared implementation behind sum_dim / sum_dim_out.

    Reduces `inp` over the dimensions in `dim` (a list of ints, possibly
    negative), optionally keeping the reduced dims as size 1, accumulating
    in `dtype`, and writing into `out` when one is provided.
    """
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        # Summing booleans counts True values; promote to int64 like PyTorch.
        dtype = torch.int64

    if dim is None:
        # No dims given: reduce over all elements.
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        return result

    if dim == []:
        # An empty dim list also means a full reduction (ATen semantics);
        # use this module's scalar `sum`.
        if not keepdim:
            return sum(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]  # normalize negative dims

    if len(dim) == 1:
        # Single-dim fast path: view inp as (M, N, K) with N the reduced axis,
        # M the product of leading dims and K the product of trailing dims.
        dim = dim[0]
        N = inp.shape[dim]
        M = reduce(lambda x, y: x * y, shape[:dim], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # Reduced axis is not innermost: K trailing elements per step.
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                sum_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                # Reduced axis is the innermost (contiguous) one.
                grid = (M, 1, 1)
                sum_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
    else:
        # Multi-dim reduction: permute the reduced dims to the end
        # (dim_compress), then reduce the flattened trailing N elements
        # of each of the M rows.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            sum_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum `inp` over `dim` (entry point for aten.sum.dim_IntList)."""
    logger.debug("GEMS SUM_DIM")
    # support dim = 0, which are consistent with PyTorch
    if inp.numel() == 0:
        # Empty input: skip the kernels and build a zero-filled result with
        # the shape PyTorch would produce for this dim/keepdim combination.
        if dtype is None:
            dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

        if dim is None or (isinstance(dim, (list, tuple)) and len(dim) == 0):
            # Full reduction: scalar, or all-ones shape when keepdim is set.
            out_shape = [1] * inp.ndim if keepdim else []
        else:
            reduce_dims = list(dim) if isinstance(dim, (list, tuple)) else [dim]
            out_shape = list(inp.shape)
            if keepdim:
                for d in reduce_dims:
                    out_shape[d % inp.ndim] = 1
            else:
                # Remove reduced dims from the back so earlier indices
                # stay valid while popping.
                axes = sorted((d % inp.ndim for d in reduce_dims), reverse=True)
                for axis in axes:
                    out_shape.pop(axis)
        out = torch.empty(out_shape, dtype=dtype, device=inp.device)
        zero_(out)
        return out
    return sum_dim_comm(inp, dim, keepdim, dtype=dtype)
def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Sum `inp` over `dim`, writing into the required `out` tensor."""
    logger.debug("GEMS SUM_DIM_OUT")
    result = sum_dim_comm(inp, dim, keepdim, dtype=dtype, out=out)
    return result