Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cumsum.py: 0%
269 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2import math
3import os
5import torch
6import triton
7import triton.language as tl
8from torch._prims_common import is_boolean_dtype, is_integer_dtype
10from flag_gems.runtime import device, torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils import triton_lang_extension as tle
# Module logger; lstrip(".") removes any leading dots from __name__ so the
# child-logger name is well-formed.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# NOTE: this rebinds the imported `device` object to its *name string*;
# every later use of `device` in this module sees a str, not the device object.
device = device.name
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-block inclusive cumsum over a flat 1-D input.

    Each program scans one BLOCK_SIZE chunk of `inp`, writes the chunk-local
    inclusive cumsum into `out`, and stores the chunk's total into
    `partial_sum[pid]` so a second pass can propagate carries across chunks.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_ptrs = inp + offset
    # NOTE(review): no `other=` on this masked load; in stock Triton the
    # masked-off lanes are undefined, so the tail block's `tl.sum` below may
    # fold in garbage. Harmless here only because the last part's partial sum
    # is never consumed downstream — confirm for this backend.
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Choose an accumulation type: 64-bit ints/floats are kept as-is, other
    # integer types accumulate in int32, everything else in float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    # Block total, used as the carry for the block to our right.
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Second pass of scan-then-fan: add the carry to each block.

    `partial_sum` is expected to already hold the inclusive scan of per-block
    totals (the host recursively scans it before this launch); block `pid`
    adds `partial_sum[pid - 1]` to every element it owns. Block 0 has no
    carry and is left untouched.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid > 0:
        # Inclusive total of everything to this block's left.
        partial_sum_ptrs = partial_sum + pid - 1
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        # Cast back in case the carry arithmetic widened the type.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-block inclusive cumsum along the middle (B) axis of an (A, B, C) view.

    Grid is (A, part_num, C): each program owns one (a, c) pair and one
    BLOCK_SIZE slice of B. It writes the block-local cumsum to `out` and its
    block total to `partial_sum`, shaped (A, part_num, C), for carry fan-out.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Row-major offsets into the (A, B, C) input and the (A, part_num, C)
    # partial-sum buffer.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Choose an accumulation type: 64-bit ints/floats are kept as-is, other
    # integer types accumulate in int32, everything else in float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    part_sum_via_sum = tl.sum(inp_vals)

    # Clamp out-of-range lanes to a sentinel offset; the store below is masked
    # anyway, so this only keeps the pointer arithmetic in-bounds.
    offset = tl.where(mask, offset, -1)
    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Second pass of scan-then-fan along the B axis of an (A, B, C) view.

    `partial_sum` (A, part_num, C) is expected to already hold the inclusive
    scan of block totals along its part axis; each program adds the carry of
    the preceding B-block to its own slice. Blocks with pid_b == 0 are skipped.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    # Carry lives at the previous part's slot in (A, part_num, C).
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid_b > 0:
        partial_sum_ptrs = partial_sum + last_part_offset
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        # Cast back in case the carry arithmetic widened the type.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
def scan_then_fan_col(inp, out, n_ele, dtype):
    """Inclusive cumsum of a flat `n_ele`-element tensor via scan-then-fan.

    Scans fixed-size blocks in parallel, recursively scans the per-block
    totals, then fans the carries back into `out`. `dtype` is the
    accumulation dtype used for the partial-sum buffer.
    """
    # TODO(all): tune on target board
    block = 1024 if n_ele > 1024 * 4 else triton.next_power_of_2(n_ele)
    num_parts = math.ceil(n_ele / block)
    part_totals = torch.empty(num_parts, dtype=dtype, device=inp.device)

    launch_grid = (num_parts,)
    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[launch_grid](
            inp, out, part_totals, n_ele, num_parts, block
        )

    if num_parts >= 2:
        # Turn per-block totals into inclusive carries, then fan them out.
        scan_then_fan_col(part_totals, part_totals, num_parts, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_kernel[launch_grid](out, part_totals, n_ele, num_parts, block)
def scan_then_fan(inp, out, A, B, C, dtype):
    """Inclusive cumsum along the middle (B) axis of an (A, B, C)-shaped view.

    Scans B in fixed-size blocks, recursively scans the per-block totals,
    then fans the carries back into `out`. `dtype` is the accumulation dtype
    for the partial-sum buffer.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    part_num = math.ceil(B / BLOCK_SIZE)
    partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

    grid = (A, part_num, C)
    if inp.shape[1] > 8192:
        # XPU simulator workaround for long scan axes — presumably required by
        # the kunlunxin Triton backend; TODO confirm exact semantics.
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        try:
            scan_part_sum_abc_kernel[grid](
                inp, out, partial_sum, B, C, part_num, BLOCK_SIZE
            )
        finally:
            # Always restore the environment, even if the launch raises;
            # the old code leaked these flags on failure.
            os.environ.pop("TRITONXPU_OTHER_SIM", None)
            os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
    else:
        with torch_device_fn.device(inp.device):
            scan_part_sum_abc_kernel[grid](
                inp, out, partial_sum, B, C, part_num, BLOCK_SIZE
            )

    if part_num >= 2:
        # Turn per-block totals into inclusive carries, then fan them out.
        scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE)
def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Dispatch cumsum along `dim` by flattening `inp` to an (M, N, K) view.

    M collects the dims before `dim`, N is the scanned dim, K the dims after.
    Integer/bool inputs promote to int64 (matching torch.cumsum); fp16/bf16
    accumulate in fp32. Writes into `out` when given, else allocates.
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    inp = inp.contiguous()
    K = inp.numel() // (M * N)

    if dtype is None:
        dtype = inp.dtype
        if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
            dtype = torch.int64
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)

    # Half-precision inputs accumulate in fp32 for accuracy.
    compute_dtype = torch.float32 if inp.dtype in (torch.float16, torch.bfloat16) else out.dtype

    if M == 1 and K == 1:
        scan_then_fan_col(inp, out, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)
    return out
def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of `inp` along `dim` (allocates the output)."""
    logger.debug("GEMS CUMSUM")
    return cumsum_wrapper(inp, dim, dtype=dtype)
def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Inclusive cumulative sum of `inp` along `dim`, written into `out`."""
    logger.debug("GEMS CUMSUM_OUT")
    return cumsum_wrapper(inp, dim, dtype=dtype, out=out)
@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Row-wise normalized cumsum: out = cumsum(row) / sum(row).

    One program handles one contiguous row of length K; BLOCK must cover K.
    """
    row_start = tle.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    # fp16 is widened for accumulation accuracy.
    # NOTE(review): bf16 is not widened here, unlike block_cumsum_kernel — confirm intended.
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r: tl.constexpr,
    t: tl.constexpr,
    R: tl.constexpr,
    K: tl.constexpr,
    r_stride: tl.constexpr,
    k_stride: tl.constexpr,
    out_r_stride: tl.constexpr,
    out_k_stride: tl.constexpr,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Chunked row cumsum with optional per-chunk totals and normalization.

    One CTA processes a (r, t*TILE) chunk:
      rows = [ grid.y * r, (grid.y + 1) * r )
      cols = [ grid.x * t * TILE, (grid.x + 1) * t * TILE )
    Carries are accumulated in fp32 across the t tiles of a chunk. When
    OUTPUT_SUMS, the running chunk total is written to sums[row, grid.x];
    when NORMALIZE, the chunk is re-read and divided by the final total
    (correct for the full row only when there is a single chunk).
    """
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running carry for this row's tiles, accumulated in fp32.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                # Switch to the output layout for the store.
                # NOTE(review): this overwrites row_offset, so subsequent
                # tiles of this row load `inp` with out_r_stride — looks
                # unintended when r_stride != out_r_stride and t > 1; confirm.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # Running (inclusive) total of this chunk so far.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep: divide the just-written chunk by the row total.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Final fan-out pass: out = (inp + base) / rscale per chunk.

    One CTA processes a (r, t*TILE) chunk:
      rows = [ grid.y, grid.y + r )
      cols = [ grid.x * t * TILE, (grid.x + 1) * t * TILE )
    `base` holds the per-(row, chunk) carry to add (exclusive prefix of chunk
    totals) and `rscale_ptr` the per-row divisor (the row total).
    """
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_gridx = tle.num_programs(1)

    # NOTE(review): `base` is laid out as (n_rows, n_chunks) by the host, but
    # this advance uses num_programs(1) and later `base += gridx` per row —
    # looks inconsistent for multi-row (r > 1) launches; confirm against the
    # single-row-per-program case actually used.
    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                # Switch to the output layout for the store.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE
391GRID_Y_LIMIT = 65535
def normed_cumsum(inp, dim=-1):
    """Cumulative sum of `inp` along `dim`, normalized by the total along `dim`.

    Equivalent to inp.cumsum(dim) / inp.sum(dim, keepdim=True) for floating
    inputs (fp16/bf16/fp32/fp64). Middle dims are transposed to the last and
    made contiguous before scanning; the result keeps that layout.

    Args:
        inp: floating-point tensor.
        dim: dimension to scan (negative values allowed).

    Returns:
        Tensor of the same shape as the (possibly transposed) input.
    """
    logger.debug("GEMS NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        num_sms = torch_device_fn.get_device_properties(device).multi_processor_count
        TILE = 8192
        # Each row is split into n_chunks of chunks where each chunk is composed
        # of n_tiles of tiles. Different chunks are assigned to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            # Fold multiple rows per program to stay under the grid-y limit.
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # One chunk covers the whole row: scan and normalize in a single
            # kernel, no cross-chunk carry propagation needed.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
                isCloseUnrollControl=True,
            )
            return out

        # Fix: fp64 inputs previously left `acc_dtype` unbound (NameError);
        # accumulate partial sums in fp32 for narrower types, fp64 for fp64.
        acc_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
        # Fix: module-level `device` is already a name *string*, so the old
        # `device=device.name` raised AttributeError; allocate alongside input.
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
            isCloseUnrollControl=True,
        )
        # Pass two, scan partial cumsums
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
            isCloseUnrollControl=True,
        )
        # Per-row total = last inclusive chunk sum; `cumsums - sums` is the
        # exclusive prefix (carry) for each chunk.
        rscale = cumsums[..., -1]
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out