Coverage for src/flag_gems/runtime/backend/_sunrise/ops/sum.py: 0%
295 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
3from functools import reduce
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.zeros import zero_
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import dim_compress, libentry, libtuner
13from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; the public entry points (sum, sum_out, sum_dim,
# sum_dim_out) emit a debug line on every call.
logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the full reduction: one partial sum per program.

    Program ``pid`` reads elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``
    of the flat input (lanes at or beyond ``M`` are masked out) and writes
    their sum to ``mid[pid]``. fp16/bf16 inputs are accumulated in float32.
    """
    # Widen half-precision inputs for the accumulation; keep other dtypes.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = inp.dtype.element_ty

    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Masked-off lanes read 0 so they do not contribute to the sum.
    chunk = tl.load(inp + idx, mask=idx < M, other=0).to(acc_dtype)
    tl.store(mid + block_id, tl.sum(chunk))
@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second pass of the full reduction: fold the partial sums into ``out``.

    Only one ``BLOCK_MID``-wide block of ``mid`` is read, so ``BLOCK_MID``
    must be at least ``mid_size``; lanes at or beyond ``mid_size`` read 0.
    fp16/bf16 partials are accumulated in float32.
    """
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = mid.dtype.element_ty

    lanes = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid + lanes, mask=lanes < mid_size, other=0).to(acc_dtype)
    tl.store(out, tl.sum(partials))
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("sum"), key=["M", "N"])
@triton.jit
def sum_kernel_dim0(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Sum a row-major ``(N, M)`` input over its leading axis into ``out[0:M]``.

    Element ``(n, m)`` is read from ``inp + n * M + m``. Program ``i`` owns
    output columns ``[i * BLOCK_M, (i + 1) * BLOCK_M)`` and walks ``N`` in
    ``BLOCK_N``-row steps, accumulating elementwise before a final reduction.
    """
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = inp.dtype.element_ty

    # Output columns owned by this program, shaped [1, BLOCK_M].
    col_idx = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[None, :]
    col_mask = col_idx < M
    col_base = inp + col_idx

    acc = tl.zeros([BLOCK_N, BLOCK_M], dtype=acc_dtype)
    for start in range(0, N, BLOCK_N):
        # Reduction rows for this step, shaped [BLOCK_N, 1].
        row_idx = start + tl.arange(0, BLOCK_N)[:, None]
        tile_mask = col_mask & (row_idx < N)
        acc += tl.load(col_base + row_idx * M, tile_mask, other=0).to(acc_dtype)
    totals = tl.sum(acc, axis=0)[None, :]
    tl.store(out + col_idx, totals, col_mask)
def sum(inp, *, dtype=None):
    """Sum every element of ``inp`` into a 0-d tensor.

    Two-pass reduction: ``sum_kernel_1`` reduces chunks of roughly
    ``sqrt(numel)`` elements into a partial-sum buffer, then ``sum_kernel_2``
    folds those partials into the result.

    Args:
        inp: input tensor; made contiguous because the kernels index it flat.
        dtype: result dtype. Defaults to ``inp.dtype``; ``torch.bool`` is
            promoted to ``torch.int64``.

    Returns:
        A 0-d tensor of dtype ``dtype``. It is zero when ``inp`` is empty.
    """
    logger.debug("GEMS SUM")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64
    elif inp.dtype is torch.bool:
        # Bool elements load as 1-bit integers in the kernel; widen them so
        # the per-chunk sums cannot wrap around.
        inp = inp.to(torch.int64)

    out = torch.empty([], dtype=dtype, device=inp.device)
    if M == 0:
        # For M == 0 the block size computed below is 0 and cdiv() would
        # divide by zero; the sum over no elements is simply 0.
        zero_(out)
        return out

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # The kernels accumulate fp16/bf16 in fp32. Keep the partial sums in fp32
    # too, so they are not rounded to half precision between the two passes.
    # The final store into `out` narrows back to `dtype`.
    half_types = (torch.float16, torch.bfloat16)
    mid_dtype = torch.float32 if dtype in half_types else dtype
    mid = torch.empty((mid_size,), dtype=mid_dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
def sum_out(inp, *, dtype=None, out):
    """Sum every element of ``inp`` and write the 0-d result into ``out``.

    Same two-pass reduction as ``sum``, but the result goes to a
    caller-provided tensor.

    Args:
        inp: input tensor; made contiguous because the kernels index it flat.
        dtype: accumulation/result dtype. Defaults to ``inp.dtype``;
            ``torch.bool`` is promoted to ``torch.int64``.
        out: destination tensor. It is resized to a 0-d shape, and the kernel
            stores into it using ``out``'s own dtype.

    Returns:
        ``out``.
    """
    logger.debug("GEMS SUM_OUT")
    # sum_kernel_1 indexes the input as a flat buffer, so strided views
    # (e.g. a transpose) must be materialized first.
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64
    elif inp.dtype is torch.bool:
        # Bool elements load as 1-bit integers in the kernel; widen them so
        # the per-chunk sums cannot wrap around.
        inp = inp.to(torch.int64)

    # A full reduction yields a 0-d result. Resize like native sum.out so a
    # zero-size placeholder is never written past its end.
    out.resize_([])
    if M == 0:
        # For M == 0 the block size computed below is 0 and cdiv() would
        # divide by zero; the sum over no elements is simply 0.
        zero_(out)
        return out

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # Keep fp16/bf16 partial sums in fp32 (the kernels already accumulate in
    # fp32). The final store narrows to out's dtype.
    half_types = (torch.float16, torch.bfloat16)
    mid_dtype = torch.float32 if dtype in half_types else dtype
    mid = torch.empty((mid_size,), dtype=mid_dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def sum_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Reduce the middle axis of a row-major ``(M, N, K)`` input to ``(M, K)``.

    Grid axis 0 picks the outer index ``m``. Grid axis 1 picks a
    ``TILE_K``-wide block of the inner index ``k``. With ``ONE_TILE_PER_CTA``,
    a single ``TILE_N`` tile is loaded along ``N``. Otherwise ``N`` is walked
    in ``TILE_N`` steps into an elementwise accumulator that is reduced once
    at the end. ``TILE_N``, ``TILE_K`` and ``ONE_TILE_PER_CTA`` are supplied
    by the heuristics decorator.
    """
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = input_ptr.dtype.element_ty

    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)

    # Inner-axis columns owned by this program, shaped [1, TILE_K].
    k_idx = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]
    k_mask = k_idx < K
    row_start = pid_m * N * K
    dst = output_ptr + (pid_m * K + k_idx)

    if ONE_TILE_PER_CTA:
        n_idx = tl.arange(0, TILE_N)[:, None]
        tile = tl.load(
            input_ptr + (row_start + n_idx * K + k_idx),
            mask=(n_idx < N) & k_mask,
            other=0,
        ).to(acc_dtype)
        tl.store(dst, tl.sum(tile, axis=0, keep_dims=True), mask=k_mask)
    else:
        # Specializing this loop did not improve performance when tested.
        acc = tl.zeros([TILE_N, TILE_K], dtype=acc_dtype)
        for n_start in range(0, N, TILE_N):
            n_idx = n_start + tl.arange(0, TILE_N)[:, None]
            acc += tl.load(
                input_ptr + (row_start + n_idx * K + k_idx),
                mask=(n_idx < N) & k_mask,
                other=0,
            ).to(acc_dtype)
        tl.store(dst, tl.sum(acc, axis=0, keep_dims=True), mask=k_mask)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def sum_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Sum each row of a row-major ``(M, N)`` input; program ``m`` writes ``output[m]``.

    With ``ONE_TILE_PER_CTA`` the row is loaded as one ``TILE_N`` tile.
    Otherwise it is walked in ``TILE_N`` steps into an elementwise
    accumulator. ``TILE_N`` and ``ONE_TILE_PER_CTA`` are supplied by the
    heuristics decorator.
    """
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = input_ptr.dtype.element_ty

    pid_m = ext.program_id(0)
    row_start = pid_m * N
    dst = output_ptr + pid_m

    if ONE_TILE_PER_CTA:
        cols = tl.arange(0, TILE_N)
        vals = tl.load(
            input_ptr + (row_start + cols), mask=cols < N, other=0
        ).to(acc_dtype)
        tl.store(dst, tl.sum(vals, axis=0))
    else:
        acc = tl.zeros([TILE_N], dtype=acc_dtype)
        for n_start in range(0, N, TILE_N):
            cols = n_start + tl.arange(0, TILE_N)
            acc += tl.load(
                input_ptr + (row_start + cols), mask=cols < N, other=0
            ).to(acc_dtype)
        tl.store(dst, tl.sum(acc, axis=0))
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def sum_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Sum each row of a row-major ``(M, N)`` input into ``out[0:M]``.

    Program ``i`` owns rows ``[i * BLOCK_M, (i + 1) * BLOCK_M)``. It walks the
    ``N`` columns in ``BLOCK_N`` steps, accumulating elementwise, and then
    reduces along the column axis. fp16/bf16 inputs accumulate in float32.
    """
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Rows owned by this program, shaped [BLOCK_M, 1].
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Elementwise `&` (broadcast to [BLOCK_M, BLOCK_N]), as in the other
        # kernels of this module; Python `and` is a scalar boolean operator.
        mask = row_mask & col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    row_sum = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, row_sum, row_mask)
def _sum_reduced_shape(shape, dims, keepdim):
    """Return ``shape`` with the non-negative axes in ``dims`` reduced.

    Reduced axes become 1 when ``keepdim`` is true and are dropped otherwise.
    """
    dim_set = set(dims)
    if keepdim:
        return [1 if i in dim_set else s for i, s in enumerate(shape)]
    return [s for i, s in enumerate(shape) if i not in dim_set]


def _sum_write_out(result, out):
    """Return ``result``, or copy it into ``out`` (resized to match) when given."""
    if out is None:
        return result
    out.resize_(result.shape)
    out.copy_(result)
    return out


def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Shared implementation of the dim-wise sum, with an optional ``out``.

    Args:
        inp: input tensor.
        dim: axes to reduce. ``None``, or an empty list/tuple, reduces every
            axis. A single ``int`` is accepted as shorthand for ``[dim]``.
            Negative axes wrap around.
        keepdim: keep reduced axes as size-1 dimensions.
        dtype: result dtype. Defaults to ``inp.dtype``; ``torch.bool`` is
            promoted to ``torch.int64``.
        out: optional destination, resized to the result shape.

    Returns:
        The reduced tensor, or ``out`` itself when one was provided.
    """
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        dtype = torch.int64
    if inp.dtype is torch.bool:
        # Bool elements load as 1-bit integers in the kernels, which
        # accumulate in the input element type. Widen first so sums of many
        # True values do not wrap around.
        inp = inp.to(dtype)

    if isinstance(dim, int):
        dim = [dim]

    if dim is None:
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        return _sum_write_out(result, out)

    if len(dim) == 0 or inp.ndim == 0:
        # An empty dim list (or a 0-d input) reduces everything.
        result = sum(inp, dtype=dtype)
        if keepdim:
            result = torch.reshape(result, [1] * inp.ndim)
        return _sum_write_out(result, out)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]

    if inp.numel() == 0:
        # Nothing to accumulate. The size computations below would divide by
        # a zero extent.
        result = torch.empty(
            _sum_reduced_shape(shape, dim, keepdim), dtype=dtype, device=inp.device
        )
        zero_(result)
        return _sum_write_out(result, out)

    if check_dim0(inp, dim):
        return sum_dim0(inp, dim, keepdim, dtype, out=out)

    out_shape = _sum_reduced_shape(shape, dim, keepdim)
    if out is None:
        out = torch.empty(out_shape, dtype=dtype, device=inp.device)
    else:
        # Resize out to the expected output shape, matching native PyTorch
        # sum.out behavior. The caller (e.g. logsumexp) may pass a zero-size
        # placeholder that needs to be resized before use.
        out.resize_(out_shape)

    if len(dim) == 1:
        # View the input as (M, N, K) with N the reduced axis; the result is
        # the contiguous (M, K) buffer of `out`.
        axis = dim[0]
        N = shape[axis]
        M = reduce(lambda x, y: x * y, shape[:axis], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                sum_dim_kernel_non_inner[grid](out, inp, M, N, K)
            else:
                sum_dim_kernel_inner[(M, 1, 1)](out, inp, M, N)
        return out

    # Several axes: move the reduced axes last and treat the input as (M, N).
    inp = dim_compress(inp, dim)
    N = reduce(lambda x, y: x * y, [shape[i] for i in dim], 1)
    M = inp.numel() // N
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_dim_kernel[grid](inp, out, M, N)
    return out
def check_dim0(inp, dim):
    """Return True when reducing ``inp`` over ``dim`` can use ``sum_dim0``.

    The fast path requires three things:
    - ``dim`` names fewer axes than ``inp`` has.
    - The kept axes are not all of extent 1.
    - Every axis before ``max(dim)`` is either reduced or has extent <= 1,
      i.e. the reduced axes are effectively the leading ones.

    Only ``inp.shape`` is inspected.
    """
    extents = list(inp.shape)
    if len(dim) == len(extents):
        return False
    for axis in dim:
        extents[axis] = 1
    if all(extent == 1 for extent in extents):
        return False
    return all(extents[i] <= 1 for i in range(max(dim)))
def sum_dim0(inp, dim, keepdim, dtype, out=None):
    """Reduce ``inp`` over its (effectively leading) axes ``dim`` via ``sum_kernel_dim0``.

    Intended for the layouts accepted by ``check_dim0``. The input is viewed
    as a row-major ``(N, M)`` array, where ``N`` is the product of the
    reduced extents.

    Args:
        inp: input tensor.
        dim: non-negative axes to reduce.
        keepdim: keep reduced axes as size-1 dimensions.
        dtype: dtype of a freshly allocated result (ignored when ``out`` is given).
        out: optional destination, resized to the result shape.

    Returns:
        The reduced tensor, or ``out`` when provided.
    """
    # sum_kernel_dim0 addresses the input as a flat row-major (N, M) buffer,
    # so strided views (e.g. a transpose) must be materialized first.
    inp = inp.contiguous()
    shape = list(inp.shape)
    dim_set = set(dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    _out_provided = out is not None
    if _out_provided:
        if keepdim:
            out.resize_(shape)
        else:
            out.resize_([s for i, s in enumerate(shape) if i not in dim_set])
    else:
        out = torch.empty(shape, dtype=dtype, device=inp.device)

    if inp.numel() == 0:
        # Empty input: the result is all zeros. This also avoids dividing by
        # a zero reduction extent below.
        zero_(out)
    else:
        M = inp.numel() // N
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            sum_kernel_dim0[grid](inp, out, M, N)
    if not keepdim and not _out_provided:
        out = out.squeeze(dim=dim)
    return out
def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over ``dim``, returning a new tensor.

    Non-empty inputs are delegated to ``sum_dim_comm``. An empty input gets a
    zero-filled result of the reduced shape, as PyTorch does: ``dim=None`` or
    an empty list/tuple reduces every axis, and an int or list/tuple of
    (possibly negative) axes reduces those. A ``bool`` dtype is promoted to
    ``int64`` on this path.
    """
    logger.debug("GEMS SUM_DIM")
    if inp.numel() != 0:
        return sum_dim_comm(inp, dim, keepdim, dtype=dtype)

    # support dim = 0, which are consistent with PyTorch
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        dtype = torch.int64

    rank = inp.ndim
    is_sequence = isinstance(dim, (list, tuple))
    if dim is None or (is_sequence and len(dim) == 0):
        result_shape = [1] * rank if keepdim else []
    else:
        axes = dim if is_sequence else [dim]
        result_shape = list(inp.shape)
        if keepdim:
            for axis in axes:
                result_shape[axis % rank] = 1
        else:
            # Drop from the highest index down so earlier pops do not shift
            # the positions still to be removed.
            for position in sorted((axis % rank for axis in axes), reverse=True):
                result_shape.pop(position)

    result = torch.empty(result_shape, dtype=dtype, device=inp.device)
    zero_(result)
    return result
def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Sum ``inp`` over ``dim`` into ``out`` by delegating to ``sum_dim_comm``."""
    logger.debug("GEMS SUM_DIM_OUT")
    result = sum_dim_comm(inp, dim, keepdim, dtype=dtype, out=out)
    return result