Coverage for src/flag_gems/runtime/backend/_aipu/ops/cumsum.py: 0%
258 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils import get_device_properties, libentry
10from flag_gems.utils import triton_lang_extension as tle
# Rebind `device` to its backend name (presumably a string identifier);
# it is later passed to get_device_properties() below — confirm type.
device = device.name
# Module-level logger per the standard `logging` convention.
logger = logging.getLogger(__name__)
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of a 1-D scan-then-fan cumsum: each program computes the
    # inclusive prefix sum of its BLOCK_SIZE-wide slice of `inp` into `out`
    # and writes the slice total into partial_sum[pid] for pass 2.
    # `part_num` is not read here; it is kept for launch-signature symmetry.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Widen narrow dtypes before accumulating: 64-bit int/uint/float pass
    # through unchanged, other ints go to int32, everything else to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)
    # Slice total, consumed by add_base_sum_kernel as the base for later blocks.
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2 (fan) of the 1-D cumsum: adds the cumulative total of all
    # preceding blocks to every element of this block. Block 0 has no
    # preceding blocks and is left untouched.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid > 0:
        # `partial_sum` has been inclusively scanned by the caller, so the
        # base for block `pid` is entry `pid - 1`.
        partial_sum_ptrs = partial_sum + pid - 1
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of scan-then-fan cumsum over the middle axis of an (A, B, C)
    # view: program (a, b, c) scans one BLOCK_SIZE stretch of the B axis for
    # a single (a, c) lane, writing the local prefix sum to `out` and the
    # stretch total into partial_sum, laid out as (A, part_num, C).
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flat offsets assume a contiguous (A, B, C) layout.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Widen narrow dtypes before accumulating: 64-bit int/uint/float pass
    # through unchanged, other ints go to int32, everything else to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)
    # Stretch total, consumed by add_base_sum_abc_kernel in pass 2.
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2 (fan) over the (A, B, C) view: adds the inclusive-scanned total
    # of all preceding B-axis stretches (partial_sum entry pid_b - 1) to this
    # stretch. The first stretch (pid_b == 0) needs no base.
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flat offsets assume a contiguous (A, B, C) layout; partial_sum is
    # (A, part_num, C).
    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid_b > 0:
        partial_sum_ptrs = partial_sum + last_part_offset
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
def scan_then_fan_col(inp, out, n_ele, dtype):
    """1-D scan-then-fan cumsum driver.

    Runs block-local inclusive scans, then recursively prefix-sums the
    per-block totals and fans them back onto `out` as per-block bases.
    """
    # TODO(all): tune on target board
    block = 1024 if n_ele > 1024 * 4 else triton.next_power_of_2(n_ele)
    part_num = math.ceil(n_ele / block)
    partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, block)

    if part_num < 2:
        return
    # Scan the block totals in place, then add them as bases for blocks > 0.
    scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)
    with torch_device_fn.device(inp.device):
        add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, block)
def scan_then_fan(inp, out, A, B, C, dtype):
    """Scan-then-fan cumsum along the middle axis of an (A, B, C) view."""
    # TODO(all): tune on target board
    block = 1024 if B > 1024 * 4 else triton.next_power_of_2(B)
    part_num = math.ceil(B / block)
    partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_sum_abc_kernel[grid](inp, out, partial_sum, B, C, part_num, block)

    if part_num < 2:
        return
    # Recursively scan the per-stretch totals, then fan them back as bases.
    scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
    with torch_device_fn.device(inp.device):
        add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, block)
def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Dispatch a cumulative sum along `dim` to the scan drivers.

    The (contiguous) tensor is viewed as (M, N, K) where N is the size of
    `dim`, M the product of the leading dims and K of the trailing dims.
    Boolean inputs accumulate into int64; fp16/bf16 accumulate in fp32.
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = inp.shape
    N = shape[dim]
    M = math.prod(shape[:dim])
    inp = inp.contiguous()
    K = inp.numel() // (M * N)

    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)

    # Half-precision inputs are accumulated in float32 for accuracy.
    compute_dtype = (
        torch.float32
        if inp.dtype in (torch.float16, torch.bfloat16)
        else out.dtype
    )

    if M == 1 and K == 1:
        scan_then_fan_col(inp, out, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)
    return out
def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of `inp` along `dim` (out-of-place)."""
    logger.debug("GEMS CUMSUM")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype)
def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Inclusive cumulative sum of `inp` along `dim`, written into `out`."""
    logger.debug("GEMS CUMSUM_OUT")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype, out=out)
@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    # Single-pass normalized cumsum, one row per program:
    # out = cumsum(x) / sum(x) over a contiguous row of length K.
    # Assumes K <= BLOCK so one tile covers the whole row — TODO confirm
    # at call sites (no caller is visible in this file).
    row_start = tle.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    if x.dtype.is_fp16():
        # Accumulate fp16 in fp32 for accuracy.
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    # Row-wise blocked cumulative sum.
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    # Flags:
    #   OUTPUT_SUMS    - also store each chunk's running total into `sums`,
    #                    addressed as row * n_chunks + gridx.
    #   NORMALIZE      - second sweep divides the stored cumsum by the chunk
    #                    total (meaningful when one chunk spans a whole row).
    #   HAS_OUT_LAYOUT - `out` uses out_r_stride/out_k_stride instead of the
    #                    input strides.
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                # Accumulate half precision in fp32.
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            # Carry the running total across tiles of this chunk.
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                # NOTE(review): row_offset is clobbered here, so with t > 1 the
                # next tile's *input* load would use the output stride. The
                # visible caller only sets HAS_OUT_LAYOUT=True with t == 1 —
                # confirm before reusing with t > 1.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep over the same tiles: divide by the chunk total.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    # Final pass of the multi-chunk normalized cumsum: each element of this
    # CTA's chunk gets the chunk's exclusive base (`base`, i.e. cumsums -
    # sums) added, then is divided by the row total read via `rscale_ptr`.
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_gridx = tle.num_programs(1)

    # NOTE(review): the caller lays `base` out as (n_rows, n_chunks), yet the
    # row stride used here is num_programs(1) and the per-row advance below is
    # `gridx` — n_chunks (= num_programs(0)) would be the natural stride for
    # both. Harmless when each CTA owns one row (batch == 1); confirm for
    # batch > 1.
    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    # NOTE(review): row range differs from block_cumsum_kernel's
    # range(gridy * r, ...); the two agree only when r == 1 — confirm.
    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE
# Maximum launch grid size along axis y; 65535 matches the classic CUDA
# grid.y limit — presumably also applies to this backend (confirm). Rows
# beyond this are folded into per-program batches in normed_cumsum.
GRID_Y_LIMIT = 65535
def normed_cumsum(inp, dim=-1):
    """Cumulative sum along ``dim`` normalized by the per-row total.

    Computes the equivalent of ``inp.cumsum(dim) / inp.sum(dim, keepdim=True)``
    for floating-point inputs. A single fused kernel pass is used when one
    chunk covers a whole row; otherwise a three-pass scheme runs: chunk-local
    scans, a scan of the chunk totals, then a base-add-and-divide update.

    Args:
        inp: tensor of dtype float16, bfloat16, float32 or float64.
        dim: dimension to scan; negative indices wrap.

    Returns:
        A new tensor with the normalized inclusive scan.
        NOTE(review): when ``dim`` is a middle dimension the computation runs
        on a ``transpose(dim, -1).contiguous()`` copy and the result is
        returned in that transposed layout — confirm callers expect this.
    """
    logger.debug("GEMS NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims are easier to handle, but transpose the middle dim
    # to the last.
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        num_sms = get_device_properties(device).multi_processor_count
        TILE = (
            2048 if K >= 2048 else triton.next_power_of_2(K)
        )  # TODO: _aipu changed TILE from 2048 to this
        # Each row is split into n_chunks chunks where each chunk is comprised
        # of n_tiles tiles. Different chunks are assigned to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            # grid.y cannot exceed GRID_Y_LIMIT: fold `batch` rows per program.
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # One chunk spans the whole row: cumsum and normalization fuse
            # into a single kernel pass.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out

        # Accumulate in fp32, except fp64 inputs which keep full precision.
        # (Fix: `acc_dtype` was previously left unbound for float64 inputs,
        # raising UnboundLocalError at the allocation below.)
        acc_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
        # (Fix: allocate on the input's device — the module-level `device` is
        # already the name string and has no `.name` attribute.)
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums (one program per row batch).
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # Pass three: add each chunk's exclusive base (cumsums - sums) and
        # divide by the row total (last entry of each row's inclusive scan).
        rscale = cumsums[..., -1]
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out