Coverage for src/flag_gems/runtime/backend/_cambricon/ops/cumsum.py: 0% of 325 statements covered (report generated by coverage.py v7.6.9 at 2026-03-15 02:11 +0800).
1import copy
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import libentry, libtuner
12from ..utils import MAX_GRID_SIZE_Y, TOTAL_CORE_NUM
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
15device = device.name
17# FIXME(cambricon): double 8192 when JIRA:1488 is fixed
18MAX_C_MLU_CUMSUM = 8192
19MAX_C_MLU_SPILT_CUMSUM = 32768
20MAX_TILE_N = 256
@triton.jit
def cumsum_blelloch_impl(
    in_block,
    DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_NUM: tl.constexpr,
):
    # Inclusive cumulative sum along the BLOCK_N axis of a
    # (BLOCK_M, BLOCK_N, BLOCK_K) tile, with BLOCK_N = TILE_NUM * TILE_N.
    # Each TILE_N-long segment is scanned with a Blelloch-style up/down sweep,
    # then segment totals are combined with tl.cumsum and broadcast back.
    x_block = tl.reshape(in_block, (BLOCK_M, TILE_NUM, TILE_N, BLOCK_K))
    # Trans TILE_N and apply blelloch in TILE_N dim
    x_block = tl.trans(x_block, 0, 2, 1, 3)
    # Apply blelloch algo
    # Up-Sweep Phase
    step = 1
    while step < TILE_N:
        idx_a = step - 1
        idx_b = idx_a + step
        while idx_b < TILE_N:
            # NOTE(review): in-place slice assignment on a tile appears to rely
            # on the Cambricon Triton extension; stock Triton tensors are
            # immutable — confirm when porting.
            x_block[:, idx_b, :, :] = x_block[:, idx_a, :, :] + x_block[:, idx_b, :, :]
            idx_a += 2 * step
            idx_b += 2 * step
        step *= 2
    # Down-Sweep Phase
    step //= 2
    while step > 0:
        idx_b = TILE_N - 1 - step
        idx_a = idx_b - step
        while idx_a > 0:
            x_block[:, idx_b, :, :] = x_block[:, idx_a, :, :] + x_block[:, idx_b, :, :]
            idx_b -= 2 * step
            idx_a -= 2 * step
        step //= 2
    # Deal the last tile row exclusive sum(Composed by right shift and tl.cumsum)
    # Right shift 1 position for the last tile row
    partial_sum = tl.zeros((BLOCK_M, TILE_NUM, BLOCK_K), dtype=tl.dtype(DTYPE))
    if TILE_NUM > 1:
        # Shift each segment's total right by one slot so, after the cumsum
        # below, segment i receives the sum of all segments before it.
        partial_sum[:, 1:, :] = x_block[:, TILE_N - 1, 0 : (TILE_NUM - 1), :]
    partial_sum = tl.cumsum(partial_sum, axis=1)
    # Apply cycle add for all tile data
    x_block += partial_sum[:, None, :, :]
    # Trans TILE_N dim to original pos
    x_block = tl.trans(x_block, 0, 2, 1, 3)
    x_block = tl.reshape(x_block, (BLOCK_M, BLOCK_N, BLOCK_K))
    return x_block
def config_prune(configs, named_args, **kwargs):
    """Early-prune autotuning configs for ``cumsum_blelloch``.

    When N fits in a single on-chip block, every BLOCK_N variant collapses to
    one config with ``BLOCK_N = next_power_of_2(N)``; for larger N,
    experimentally derived filters drop configs known to be invalid or
    uncompetitive.  Surviving configs are de-duplicated by their tuning key.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel launch arguments; ``M`` and ``N`` are read.

    Returns:
        List of pruned, de-duplicated configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, TILE_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        new_config = config
        # When N is less than MAX_C_MLU_CUMSUM, no reduction loops. Unify different BLOCK_N configs.
        if N <= MAX_C_MLU_CUMSUM:
            # Deep-copy before mutating so the shared config list stays intact.
            new_config = copy.deepcopy(config)
            BLOCK_N = new_config.kwargs["BLOCK_N"] = triton.next_power_of_2(N)
            num_stages = new_config.num_stages = 1
        else:
            # When N is greater than MAX_C_MLU_CUMSUM, the pruning condition was obtained through experimentation.
            # It may result in not finding the optimal solution.
            if BLOCK_N < 2048:
                continue
            if BLOCK_N >= 2048 and TILE_N < 8:
                continue
            if (
                BLOCK_N < MAX_C_MLU_CUMSUM
                and BLOCK_M < M
                and BLOCK_M <= (MAX_C_MLU_CUMSUM // BLOCK_N * 2)
            ):
                continue
            # BLOCK_M can only be 1 when BLOCK_N is at its maximum
            if BLOCK_N == MAX_C_MLU_CUMSUM and BLOCK_M > 1:
                continue
        # Prune invalid BLOCK_M
        if BLOCK_M > M:
            continue
        # Prune invalid TILE_N
        if TILE_N > BLOCK_N:
            continue
        # The pruning condition was obtained through experimentation. It may result in not finding the optimal solution.
        if BLOCK_N > 128 and TILE_N < 8:
            continue
        key = (BLOCK_M, BLOCK_N, TILE_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, new_config)
    # Idiom fix: return the surviving values directly instead of an append
    # loop over .items() that ignored the keys.
    return list(configs_map.values())
@libentry()
@libtuner(
    configs=[
        triton.Config(
            {
                "BLOCK_M": m,
                "BLOCK_N": 2**n,
                "TILE_N": 2**t,
            },
            num_stages=s,
            num_warps=1,
        )
        for m in range(1, 20, 3)
        for n in range(7, 13, 1)
        for t in range(0, 7, 1)
        for s in [1, 3]
    ],
    key=[
        "M",
        "N",
        "K",
    ],
    strategy=["log", "log", "log"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.heuristics(
    values={
        # TILE_NUM is the number of TILE_N segments per BLOCK_N; fall back to a
        # single segment when TILE_N does not divide BLOCK_N evenly.
        "TILE_NUM": lambda args: args["BLOCK_N"] // args["TILE_N"]
        if args["BLOCK_N"] % args["TILE_N"] == 0
        and args["BLOCK_N"] // args["TILE_N"] >= 1
        else 1,
        "TILE_N": lambda args: args["BLOCK_N"]
        if args["TILE_NUM"] == 1
        else args["TILE_N"],
    },
)
@triton.jit
def cumsum_blelloch(
    inp,
    out,
    M,
    N,
    K,
    DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_NUM: tl.constexpr,
):
    # Single-kernel cumsum over the N axis of an (M, N, K) view: each program
    # owns BLOCK_M rows and one K slice, looping over N in BLOCK_N chunks with
    # `kep` carrying the running total across chunks.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Carry buffer; only its last column (the previous chunk's final cumsum)
    # is read each iteration.
    kep = tl.full([BLOCK_M, BLOCK_N, 1], float(0), tl.dtype(DTYPE))
    for col_offset in range(0, N, BLOCK_N):
        n_offset = col_offset + tl.arange(0, BLOCK_N)
        # Pointers to the start of the row
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        x_ptrs = inp + offsets
        y_ptrs = out + offsets

        # Load data into NRAM
        in_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.dtype(DTYPE))

        x_block = cumsum_blelloch_impl(
            in_block, DTYPE, BLOCK_M, BLOCK_N, 1, TILE_N, TILE_NUM
        )
        # Add last block partial sum to current block
        x_block = tl.reshape(x_block, (BLOCK_M, BLOCK_N))
        kep_tmp = kep[:, BLOCK_N - 1, :]
        x_block += kep_tmp
        kep = x_block[:, :, None]
        # Store result back to global memory
        tl.store(y_ptrs, x_block, mask=mask)
def get_reduction_dim_block_size(N):
    """Pick a power-of-2 per-core block size for splitting N across cores."""
    # Ceiling division of N over the core count.
    per_core, rem = divmod(N, TOTAL_CORE_NUM)
    block = per_core + (1 if rem else 0)
    # Cap the slice so a whole block still fits on chip.
    block = min(block, MAX_C_MLU_SPILT_CUMSUM)
    # In blelloch, block_size = TILE_N * TILE_NUM, and both must be powers of
    # two — so round the block size up to a power of two as well.
    return triton.next_power_of_2(block)
def config_prune_mid(configs, named_args, **kwargs):
    """Early-prune autotuning configs for ``cumsum_kernel_mid``.

    Drops configs whose block shape exceeds the problem size or the on-chip
    memory budget, then de-duplicates by tuning key.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel launch arguments; ``M``, ``K`` and ``BLOCK_N`` are read.

    Returns:
        List of pruned, de-duplicated configs.
    """
    M = named_args["M"]
    K = named_args["K"]
    BLOCK_N = named_args["BLOCK_N"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_K, TILE_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_K"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        new_config = config
        # Prune invalid BLOCK_M
        if BLOCK_M > M:
            continue
        # Prune invalid BLOCK_K
        if BLOCK_K > K:
            continue
        # Keep the whole (BLOCK_M, BLOCK_N, BLOCK_K) tile within budget.
        if BLOCK_N * BLOCK_K * BLOCK_M > MAX_C_MLU_SPILT_CUMSUM:
            continue
        # Prune invalid TILE_N
        if TILE_N > BLOCK_N:
            continue
        # The pruning condition was obtained through experimentation. It may result in not finding the optimal solution.
        if BLOCK_N > 128 and TILE_N < 8:
            continue
        key = (BLOCK_M, BLOCK_N, BLOCK_K, TILE_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, new_config)
    # Idiom fix: return the surviving values directly instead of an append
    # loop over .items() that ignored the keys.
    return list(configs_map.values())
@libentry()
@libtuner(
    configs=[
        triton.Config(
            {
                "BLOCK_M": m,
                "BLOCK_K": 2**k,
                "TILE_N": 2**t,
            },
            num_stages=s,
            num_warps=1,
        )
        for m in range(1, 10, 3)
        for k in range(0, 3, 1)
        for t in range(5, int(math.log(MAX_TILE_N, 2) + 1), 1)
        for s in [1, 3]
    ],
    key=[
        "M",
        "N",
        "K",
        "BLOCK_N",
    ],
    strategy=["log", "log", "log", "log"],
    prune_configs_by={"early_config_prune": config_prune_mid},
)
@triton.heuristics(
    values={
        "TILE_NUM": lambda args: args["BLOCK_N"] // args["TILE_N"]
        if args["BLOCK_N"] % args["TILE_N"] == 0
        and args["BLOCK_N"] // args["TILE_N"] >= 1
        else 1,
        "TILE_N": lambda args: args["BLOCK_N"]
        if args["TILE_NUM"] == 1
        else args["TILE_N"],
    },
)
@triton.jit
def cumsum_kernel_mid(
    inp,
    out,
    prefix_sum,
    M,
    N,
    K,
    BLOCK_N: tl.constexpr,
    DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_NUM: tl.constexpr,
):
    # First pass of the split-N cumsum: each program scans its own BLOCK_N
    # chunk independently and writes the chunk's last element (its total) into
    # `prefix_sum[m, pid_n, k]` for the cross-chunk fix-up passes.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_jobs_n = tl.num_programs(1)
    pid_k = tl.program_id(2)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    k_offset = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    n_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offsets = (
        m_offset[:, None, None] * N * K
        + n_offset[
            None,
            :,
            None,
        ]
        * K
        + k_offset[None, None, :]
    )
    mask = (m_offset[:, None, None] < M and n_offset[None, :, None] < N) and k_offset[
        None, None, :
    ] < K
    x_ptrs = inp + offsets
    y_ptrs = out + offsets

    # Load data into NRAM
    in_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.dtype(DTYPE))

    x_block = cumsum_blelloch_impl(
        in_block, DTYPE, BLOCK_M, BLOCK_N, BLOCK_K, TILE_N, TILE_NUM
    )
    tl.store(y_ptrs, x_block, mask=mask)
    # Chunk total (last scanned element) -> prefix_sum[m, pid_n, k].
    prefix_sum_offsets = (
        m_offset[:, None] * num_jobs_n * K + pid_n * K + k_offset[None, :]
    )
    prefix_sum_mask = m_offset[:, None] < M and k_offset[None, :] < K
    prefix_sum_ptrs = prefix_sum + prefix_sum_offsets
    tl.store(prefix_sum_ptrs, x_block[:, BLOCK_N - 1, :], prefix_sum_mask)
@libentry()
@libtuner(
    configs=[
        triton.Config(
            {
                "BLOCK_M": m,
                "BLOCK_K": 2**k,
            },
            num_stages=s,
            num_warps=1,
        )
        for m in [1, 3, 6]
        for k in range(0, 3, 1)
        for s in [1, 3]
    ],
    key=[
        "M",
        "N",
        "K",
        "BLOCK_N",
    ],
    strategy=["log", "log", "log", "log"],
)
@triton.jit
def cumsum_kernel_result(
    inp,
    prefix_sum,
    out,
    M,
    N,
    K,
    BLOCK_N: tl.constexpr,
    DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Final pass of the split-N cumsum: add the scanned total of all preceding
    # chunks (prefix_sum entry pid_n - 1) to every element of chunk pid_n.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    num_jobs_n = tl.num_programs(1)
    pid_k = tl.program_id(2)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k_offset = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    offsets = (
        m_offset[:, None, None] * N * K
        + n_offset[
            None,
            :,
            None,
        ]
        * K
        + k_offset[None, None, :]
    )
    mask = (m_offset[:, None, None] < M and n_offset[None, :, None] < N) and k_offset[
        None, None, :
    ] < K
    x_ptrs = inp + offsets
    y_ptrs = out + offsets

    # Load data into NRAM
    x_block = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.dtype(DTYPE))

    # Chunk 0 has no preceding chunks and needs no offset.
    if pid_n > 0:
        sum_offsets = (
            m_offset[:, None] * num_jobs_n * K + (pid_n - 1) * K + k_offset[None, :]
        )
        sum_mask = m_offset[:, None] < M and k_offset[None, :] < K
        sum_ptrs = prefix_sum + sum_offsets
        sum_block = tl.load(sum_ptrs, mask=sum_mask, other=0.0).to(tl.dtype(DTYPE))
        x_block += sum_block[:, None, :]

    # Store result back to global memory
    tl.store(y_ptrs, x_block, mask=mask)
def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Compute the inclusive cumulative sum of ``inp`` along ``dim``.

    The tensor is viewed as (M, N, K) with N the scanned dimension, M the
    collapsed leading dims and K the collapsed trailing dims.  When there are
    few rows but a very long N, the scan is split into three kernel launches
    (per-chunk scan, scan of chunk totals, fix-up add); otherwise a single
    looping kernel is used.

    Args:
        inp: input tensor (made contiguous internally).
        dim: dimension to scan; may be negative.
        dtype: output dtype; defaults to ``inp.dtype`` (bool promotes to int32).
        out: optional preallocated output tensor.

    Returns:
        Tensor of cumulative sums with the same shape as ``inp``.
    """
    # NOTE: assert-based validation is stripped under ``python -O``.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    M = 1
    N = shape[dim]
    for i in range(dim):
        M *= shape[i]
    inp = inp.contiguous()
    K = inp.numel() // M // N

    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int32
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)

    blelloch_grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )

    # Kernel-side accumulation type.  NOTE(review): fp64 is demoted to fp32 and
    # int64 to int32 here — presumably acceptable on this backend; confirm.
    dtypestr = "fp32" if torch.is_floating_point(out) else "int32"
    if (M * K < TOTAL_CORE_NUM / 2) and (N > MAX_C_MLU_CUMSUM):
        # Too few rows to fill the cores: split the long N across programs.
        # result BLOCK_N must be same as mid BLOCK_N
        mid_out = torch.empty_like(inp, dtype=dtype)
        BLOCK_N = get_reduction_dim_block_size(N)
        prefix_sum_inp = torch.empty(
            M, triton.cdiv(N, BLOCK_N), K, dtype=dtype, device=inp.device
        )
        prefix_sum = torch.empty(
            M, triton.cdiv(N, BLOCK_N), K, dtype=dtype, device=inp.device
        )
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(N, BLOCK_N),
            triton.cdiv(K, meta["BLOCK_K"]),
        )
        with torch_device_fn.device(inp.device):
            # Pass 1: per-chunk scan; chunk totals land in prefix_sum_inp.
            cumsum_kernel_mid[grid](
                inp, mid_out, prefix_sum_inp, M, N, K, BLOCK_N, dtypestr
            )
            # Pass 2: scan the chunk totals along the chunk axis.
            cumsum_blelloch[blelloch_grid](
                prefix_sum_inp, prefix_sum, M, triton.cdiv(N, BLOCK_N), K, dtypestr
            )
            # Pass 3: add each chunk's exclusive prefix to its elements.
            cumsum_kernel_result[grid](
                mid_out, prefix_sum, out, M, N, K, BLOCK_N, dtypestr
            )
    else:
        with torch_device_fn.device(inp.device):
            cumsum_blelloch[blelloch_grid](inp, out, M, N, K, dtypestr)
    return out
def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of ``inp`` along ``dim`` (Cambricon backend)."""
    logger.debug("GEMS_CAMBRICON CUMSUM")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype)
def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Inclusive cumulative sum of ``inp`` along ``dim``, written into ``out``."""
    logger.debug("GEMS_CAMBRICON CUMSUM_OUT")
    return cumsum_wrapper(inp, dim=dim, dtype=dtype, out=out)
@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    # Each program normalizes one length-K row: out = cumsum(x) / sum(x).
    base = tl.program_id(0) * K
    offs = tl.arange(0, BLOCK)
    vals = tl.load(inp + base + offs, mask=offs < K, other=0)
    # fp16 rows are scanned in fp32 to limit rounding error.
    if vals.dtype.is_fp16():
        vals = vals.to(tl.float32)
    total = tl.sum(vals, 0)
    scan = tl.cumsum(vals, 0)
    scan = scan / total
    tl.store(out + base + offs, scan, mask=offs < K)
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    # Row-wise inclusive scan over a chunk of t tiles of width TILE, with
    # optional per-chunk total output (OUTPUT_SUMS) and optional division by
    # the final running total (NORMALIZE).
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tl.program_id(0).to(tl.int64)
    gridy = tl.program_id(1).to(tl.int64)
    n_chunks = tl.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running total carried across the t tiles of this chunk (fp32 accum).
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                # NOTE(review): this overwrites row_offset, so with t > 1 the
                # next tile would also *load* using out_r_stride; the only
                # HAS_OUT_LAYOUT call site launches with t == 1 — confirm
                # before reusing this flag with t > 1.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # Record this chunk's running total at sums[row, chunk].
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep: divide every stored element by the row total.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    # Fix-up pass for the multi-chunk normalized cumsum: adds each chunk's
    # exclusive prefix (`base`, i.e. cumsums - sums from the caller) and
    # divides by the per-row total (`rscale_ptr`).
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tl.program_id(0).to(tl.int64)
    gridy = tl.program_id(1).to(tl.int64)
    n_gridx = tl.num_programs(1)

    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    # NOTE(review): `base` is indexed with n_gridx = num_programs(1) (the batch
    # count) as the row stride, but the caller's `sums`/`cumsums` buffers have
    # n_chunks = num_programs(0) columns per row, and the per-row advance is
    # `base += gridx` rather than a fixed stride.  This only lines up for
    # particular launch shapes — verify against the reference implementation
    # before relying on multi-row batches here.
    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE
# Upper bound for the launch grid's Y dimension; normed_cumsum folds rows
# beyond this limit into per-CTA batches.
GRID_Y_LIMIT = MAX_GRID_SIZE_Y
def normed_cumsum(inp, dim=-1):
    """Cumulative sum along ``dim``, normalized by the per-row total.

    Equivalent to ``inp.cumsum(dim) / inp.sum(dim, keepdim=True)`` for
    floating-point inputs; fp16/bf16 rows accumulate in fp32.

    Args:
        inp: floating-point tensor (fp16/bf16/fp32/fp64).
        dim: dimension to scan; may be negative.

    Returns:
        New tensor of the same shape holding the normalized cumulative sums.
    """
    logger.debug("GEMS_CAMBRICON NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        num_sms = TOTAL_CORE_NUM
        TILE = 2048
        # Each row is split into n_chunks of chunks where each chunk is composed
        # of n_tiles of tiles. Different chunks are assigned to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # Single chunk per row: one fused pass scans and normalizes.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out

        # FIX: acc_dtype used to be assigned only for non-fp64 inputs, so a
        # float64 tensor (accepted by the assert above) raised NameError here.
        acc_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
        # FIX: allocate on the input's device; the old `device=device.name`
        # failed because module-level `device` is already the name string.
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        # Pass one: per-chunk scan; chunk totals are written to `sums`.
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # Pass three: add each chunk's exclusive prefix and divide by row total.
        rscale = cumsums[..., -1]
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out
733 return out