Coverage for src/flag_gems/ops/sum.py: 40%
206 statements
coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2import math
3from functools import reduce
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import dim_compress, libentry, libtuner
12from flag_gems.utils import triton_lang_extension as tle
14logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First reduction stage: one partial sum per program.

    Program ``p`` loads the ``BLOCK_SIZE`` elements of ``inp`` starting at
    offset ``p * BLOCK_SIZE``, substitutes 0 for offsets at or beyond ``M``,
    and stores the total of that slice at ``mid + p``.
    """
    # float16 / bfloat16 inputs are summed in float32; any other element
    # type is summed in its own dtype.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = inp.dtype.element_ty

    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M

    vals = tl.load(inp + idx, mask=in_bounds, other=0).to(acc_dtype)
    partial = tl.sum(vals)
    tl.store(mid + block_id, partial)
@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second reduction stage: fold the partial sums in ``mid`` into ``out``.

    A single block of ``BLOCK_MID`` lanes loads the first ``mid_size``
    entries of ``mid`` (lanes past ``mid_size`` read 0), sums them, and
    stores the scalar result at ``out``.
    """
    # float16 / bfloat16 partials are summed in float32; any other element
    # type is summed in its own dtype.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = mid.dtype.element_ty

    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    partials = tl.load(mid + lanes, mask=valid, other=0).to(acc_dtype)
    total = tl.sum(partials)
    tl.store(out, total)
def sum(inp, *, dtype=None):
    """Return the sum of all elements of ``inp`` as a 0-dim tensor.

    The reduction runs in two kernel launches: ``sum_kernel_1`` writes
    ``mid_size`` partial sums into a temporary buffer and ``sum_kernel_2``
    folds that buffer into the scalar result.

    Args:
        inp: Input tensor. It is made contiguous before the kernels read it
            because they address it as a flat array.
        dtype: Dtype of the intermediate buffer and of the result. Defaults
            to ``inp.dtype``; a bool input with no explicit ``dtype`` is
            converted to ``torch.int64`` first.

    Returns:
        A 0-dim tensor of ``dtype`` on ``inp.device``. For an empty input
        the result is 0.
    """
    logger.debug("GEMS SUM")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        # NOTE(review): the bool promotion is assumed to apply only when no
        # dtype was requested (indentation was lost in this listing) - confirm.
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    # An empty input would make block_size 0 and triton.cdiv(0, 0) divide by
    # zero; the sum of no elements is 0, so answer directly.
    if M == 0:
        return torch.zeros([], dtype=dtype, device=inp.device)

    # Blocks of roughly sqrt(M) elements (rounded up to a power of two) keep
    # the number of partial sums close to the block width.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
def sum_out(inp, *, dtype=None, out):
    """Sum all elements of ``inp`` and write the scalar result into ``out``.

    Uses the same two-launch scheme as ``sum``: ``sum_kernel_1`` produces
    ``mid_size`` partial sums and ``sum_kernel_2`` folds them into ``out``.

    Args:
        inp: Input tensor. It is made contiguous before the kernels read it
            because they address it as a flat array.
        dtype: Dtype of the intermediate buffer. Defaults to ``inp.dtype``;
            a bool input with no explicit ``dtype`` is converted to
            ``torch.int64`` first.
        out: Tensor that receives the result; the kernel stores a single
            value at its start.

    Returns:
        ``out``. For an empty input it is filled with 0.
    """
    logger.debug("GEMS SUM_OUT")
    # The kernels index the input linearly, so a strided view must be
    # materialised first (``sum`` does the same).
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        # NOTE(review): the bool promotion is assumed to apply only when no
        # dtype was requested (indentation was lost in this listing) - confirm.
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    # An empty input would make block_size 0 and triton.cdiv(0, 0) divide by
    # zero; the sum of no elements is 0.
    if M == 0:
        out.zero_()
        return out

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def sum_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Sum over the middle axis of an input addressed as (M, N, K).

    Element (m, n, k) is read from offset ``m * N * K + n * K + k`` and the
    total over ``n`` is stored at offset ``m * K + k``. Grid axis 0 selects
    ``m``; grid axis 1 selects a group of ``TILE_K`` consecutive ``k`` values.
    """
    # float16 / bfloat16 inputs are summed in float32; any other element
    # type is summed in its own dtype.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = input_ptr.dtype.element_ty

    row = tle.program_id(0)
    k_block = tle.program_id(1)

    k_idx = k_block * TILE_K + tl.arange(0, TILE_K)[None, :]
    k_ok = k_idx < K

    if ONE_TILE_PER_CTA:
        # A single TILE_N x TILE_K tile is loaded; rows at or beyond N read 0.
        n_idx = tl.arange(0, TILE_N)[:, None]
        offsets = row * N * K + n_idx * K + k_idx
        tile_mask = (n_idx < N) & k_ok
        tile = tl.load(input_ptr + offsets, mask=tile_mask, other=0).to(acc_dtype)
        reduced = tl.sum(tile, axis=0, keep_dims=True)
    else:
        acc = tl.zeros([TILE_N, TILE_K], dtype=acc_dtype)

        # specialization does not improve performance in this example, as tested
        for n_start in range(0, N, TILE_N):
            n_idx = n_start + tl.arange(0, TILE_N)[:, None]
            offsets = row * N * K + n_idx * K + k_idx
            tile_mask = (n_idx < N) & k_ok
            tile = tl.load(input_ptr + offsets, mask=tile_mask, other=0).to(acc_dtype)
            acc += tile
        reduced = tl.sum(acc, axis=0, keep_dims=True)

    # Both paths finish with the same masked store of one value per k.
    tl.store(output_ptr + (row * K + k_idx), reduced, mask=k_ok)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def sum_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Sum each length-``N`` row of an input addressed as (M, N).

    Program ``m`` reads the elements at offsets ``m * N + n`` for ``n < N``
    and stores their total at offset ``m`` of ``output_ptr``.
    """
    # float16 / bfloat16 inputs are summed in float32; any other element
    # type is summed in its own dtype.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = input_ptr.dtype.element_ty

    row = tle.program_id(0)

    if ONE_TILE_PER_CTA:
        # A single TILE_N-wide load; lanes at or beyond N read 0.
        n_idx = tl.arange(0, TILE_N)
        offsets = row * N + n_idx
        lane_ok = n_idx < N
        vals = tl.load(input_ptr + offsets, mask=lane_ok, other=0).to(acc_dtype)
        total = tl.sum(vals, axis=0)
    else:
        # Walk the row in TILE_N-wide steps, accumulating lane-wise, then
        # collapse the lanes once at the end.
        acc = tl.zeros([TILE_N], dtype=acc_dtype)
        for n_start in range(0, N, TILE_N):
            n_idx = n_start + tl.arange(0, TILE_N)
            offsets = row * N + n_idx
            lane_ok = n_idx < N
            vals = tl.load(input_ptr + offsets, mask=lane_ok, other=0).to(acc_dtype)
            acc += vals
        total = tl.sum(acc, axis=0)

    # Both paths finish with the same single-value store for this row.
    tl.store(output_ptr + row, total)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def sum_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Sum each length-``N`` row of an input addressed as (M, N).

    Each program handles ``BLOCK_M`` consecutive rows, walks their columns in
    steps of ``BLOCK_N`` and stores one total per row at ``out + row``.
    """
    # float16 / bfloat16 inputs are summed in float32; any other element
    # type is summed in its own dtype.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = inp.dtype.element_ty

    # Rows of inp this program is responsible for, as a column vector.
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ptrs = inp + rows * N
    out_ptrs = out + rows
    row_ok = rows < M

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=acc_dtype)
    for col_start in range(0, N, BLOCK_N):
        cols = col_start + tl.arange(0, BLOCK_N)[None, :]
        # Only in-range rows and columns are loaded; the rest read 0.
        tile_mask = row_ok & (cols < N)
        acc += tl.load(row_ptrs + cols, tile_mask, other=0).to(acc_dtype)

    row_totals = tl.sum(acc, axis=1)[:, None]
    tl.store(out_ptrs, row_totals, row_ok)
def _finish_full_reduction(result, ndim, keepdim, out):
    """Shape a 0-dim full-reduction result and deliver it, honouring ``out``."""
    if keepdim:
        result = result.reshape([1] * ndim)
    if out is not None:
        # The caller asked for the result in its own tensor; do not drop it.
        out.copy_(result)
        return out
    return result


def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Sum ``inp`` over ``dim``; shared implementation of ``sum_dim`` and ``sum_dim_out``.

    Args:
        inp: Input tensor.
        dim: ``None``, an int, or a sequence of ints (negative values count
            from the end). ``None`` and an empty sequence both reduce over
            every element.
        keepdim: If true, reduced dimensions are kept with size 1; otherwise
            they are squeezed out of the result.
        dtype: Result dtype. Defaults to ``inp.dtype``; a bool input with no
            explicit ``dtype`` produces ``torch.int64``.
        out: Optional tensor to write into. When reducing over specific
            dimensions the kernels store into it directly, so it must already
            hold one element per output position.

    Returns:
        The reduced tensor. When ``out`` is given and ``keepdim`` is false on
        the per-dimension paths, the return value is ``out`` squeezed.
    """
    if dtype is None:
        dtype = inp.dtype
        # NOTE(review): the bool promotion is assumed to apply only when no
        # dtype was requested (indentation was lost in this listing) - confirm.
        if dtype is torch.bool:
            dtype = torch.int64

    ndim = inp.ndim

    if dim is None:
        return _finish_full_reduction(torch.sum(inp, dtype=dtype), ndim, keepdim, out)

    # Accept a bare int as well as a list/tuple of ints.
    if isinstance(dim, int):
        dim = [dim]

    # No dims listed, or a 0-dim input (where ``d % ndim`` would divide by
    # zero): reduce over everything.
    if len(dim) == 0 or ndim == 0:
        if inp.numel() == 0:
            total = torch.zeros([], dtype=dtype, device=inp.device)
        else:
            total = sum(inp, dtype=dtype)
        return _finish_full_reduction(total, ndim, keepdim, out)

    shape = list(inp.shape)
    dim = [d % ndim for d in dim]

    # Empty input: there is nothing to accumulate and the size computations
    # below would divide by zero, so produce zeros without launching a kernel.
    if inp.numel() == 0:
        for d in dim:
            shape[d] = 1
        if out is None:
            out = torch.zeros(shape, dtype=dtype, device=inp.device)
        else:
            out.zero_()
        if not keepdim:
            out = out.squeeze(dim=dim[0] if len(dim) == 1 else dim)
        return out

    if len(dim) == 1:
        # View the contiguous input as (M, N, K) with N the reduced axis.
        dim = dim[0]
        N = inp.shape[dim]
        M = reduce(lambda x, y: x * y, shape[:dim], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # Reduced axis is not innermost: tile over K as well.
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                sum_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                # K == 1: each of the M rows is N consecutive elements.
                grid = (M, 1, 1)
                sum_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
    else:
        # Several dims: rearrange with dim_compress, then treat the input as
        # (M, N) where N is the product of the reduced sizes.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            sum_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over ``dim`` and return a newly allocated result.

    Logs the call and forwards every argument to ``sum_dim_comm`` without
    an ``out`` tensor.
    """
    logger.debug("GEMS SUM_DIM")
    return sum_dim_comm(inp, dim=dim, keepdim=keepdim, dtype=dtype)
def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Sum ``inp`` over ``dim``, passing ``out`` through as the destination.

    Logs the call and forwards every argument, including ``out``, to
    ``sum_dim_comm``.
    """
    logger.debug("GEMS SUM_DIM_OUT")
    return sum_dim_comm(inp, dim=dim, keepdim=keepdim, dtype=dtype, out=out)