Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/unique.py: 0%
524 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import os
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as tle
9from flag_gems.utils.libentry import libentry
@libentry()
@triton.jit
def simple_unique_flat_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    unique_size_ptr: tl.tensor,  # out
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Single-tile unique over an already-sorted flat input.

    Loads the whole input in one tile (assumes tile_size >= num_tasks --
    TODO confirm at the launch site), flags elements that differ from their
    predecessor, and uses the inclusive cumsum of those flags as the
    compacted output slot of every element.

    Writes:
      * data_out_ptr[slot]  -- the unique values
      * unique_size_ptr[0]  -- the last lane's cumsum (number of uniques
                               minus one, assuming masked-out lanes add 0)
      * inverse_indices_ptr (if return_inverse) -- original position -> slot
      * idx_ptr (if return_counts) -- slot -> first sorted position of the run
    """
    i0 = tl.arange(0, tile_size)
    mask = i0 < num_tasks

    # load each element and its predecessor (lane 0 is paired with itself,
    # so its "not equal" flag is forced to 0 below)
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # ne & cumsum: inclusive cumsum of "differs from predecessor" gives the
    # 0-based output slot of every element
    ne_result = tl.where(i0 > 0, a != b, 0)
    cumsum = tl.cumsum(ne_result)

    # unique_size: only the last lane stores its cumsum to unique_size_ptr[0];
    # every other lane is redirected to offset -1 and masked off
    unique_size_mask = i0 == tile_size - 1
    unique_off = tl.where(unique_size_mask, tl.zeros_like(i0), -1)
    tl.store(unique_size_ptr + unique_off, cumsum, mask=unique_size_mask)

    # data_out: scatter_(to=cumsum, sorted_data)
    data_out_off = tl.where(mask, cumsum, -1)
    tl.store(data_out_ptr + data_out_off, a, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum) -- maps each
    # original (pre-sort) position to its slot in the unique output
    if return_inverse:
        sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)
        tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: record, per unique value, the first sorted position where it
    # occurs (run starts are lane 0 plus every "not equal" lane)
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask
        tl.store(idx_ptr + cumsum, i0, mask=idx_mask)
@triton.jit
def output_counts_flat_impl(
    global_pid,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Tile body: counts[i] = idx[i + 1] - idx[i].

    idx holds, for each unique value, the first position where it appears in
    the sorted input, so adjacent differences are the occurrence counts.  The
    final entry has no successor and is closed with origin_num_tasks (the
    length of the original sorted input).
    """
    r = tl.arange(0, tile_size)

    # load idx for this tile
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    idx = tl.load(idx_ptr + i0, mask=mask)

    # load idx_next (masked off past the last entry)
    i0_next = i0 + 1
    next_mask = i0_next < num_tasks
    idx_next = tl.load(idx_ptr + i0_next, mask=next_mask)

    # diff; the last count is bounded by the original input length
    counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx)

    # store counts
    tl.store(counts_ptr + i0, counts, mask=mask)
@libentry()
@triton.jit
def output_counts_flat_kernel(
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Grid-stride launch wrapper: each program walks tiles_per_cta tiles,
    stepping by the program count, and hands every tile to
    output_counts_flat_impl."""
    block_id = tle.program_id(0)
    stride = tle.num_programs(0)
    for step in range(tiles_per_cta):
        tile_id = block_id + step * stride
        output_counts_flat_impl(
            tile_id,
            idx_ptr,
            origin_num_tasks,
            counts_ptr,
            num_tasks,
            tile_size,
        )
@triton.jit
def quick_output_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    data_out_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Tile body: produce counts and gather unique values in one pass.

    counts[i] = idx[i + 1] - idx[i] (last entry closed with
    origin_num_tasks), and data_out[i] = sorted_data[idx[i]] -- i.e. the
    representative value of each unique run.
    """
    r = tl.arange(0, tile_size)

    # load idx for this tile
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    idx = tl.load(idx_ptr + i0, mask=mask)

    # load idx_next (masked off past the last entry)
    i0_next = i0 + 1
    next_mask = i0_next < num_tasks
    idx_next = tl.load(idx_ptr + i0_next, mask=next_mask)

    # diff; the last count is bounded by the original input length
    counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx)

    # store counts
    tl.store(counts_ptr + i0, counts, mask=mask)

    # data_out: gather(sorted_data, from=idx) -- first element of each run
    sorted_data = tl.load(sorted_data_ptr + idx, mask=mask)
    tl.store(data_out_ptr + i0, sorted_data, mask=mask)
@libentry()
@triton.jit
def quick_output_flat_kernel(
    sorted_data_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    data_out_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Grid-stride launch wrapper: each program covers tiles_per_cta tiles,
    stepping by the program count, delegating each tile to
    quick_output_flat_impl."""
    block_id = tle.program_id(0)
    stride = tle.num_programs(0)
    for step in range(tiles_per_cta):
        tile_id = block_id + step * stride
        quick_output_flat_impl(
            tile_id,
            sorted_data_ptr,
            idx_ptr,
            origin_num_tasks,
            data_out_ptr,
            counts_ptr,
            num_tasks,
            tile_size,
        )
@triton.jit
def local_quick_unique_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,  # in
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Multi-tile unique, stage 1: per-tile compaction.

    Within its tile, flags elements that differ from their predecessor and
    compacts the flagged entries to the front of the tile's own region
    (starting at global_pid * tile_size).  Depending on return_counts the
    compacted payload is either the first sorted position of each run
    (origin_idx_ptr) or the unique values themselves (local_unique_ptr).
    The tile's local unique count is published to tile_sum_ptr[global_pid]
    for the later global-prefix pass.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load element and predecessor (global element 0 compares with itself)
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # ne & cumsum: per-tile inclusive cumsum of "differs from predecessor"
    ne_result = tl.where(i0 > 0, a != b, 0)
    cumsum = tl.cumsum(ne_result)

    # local_id or local_unique: tile 0 keeps cumsum as-is (its first element
    # is slot 0); later tiles shift down by one because their first lane's
    # flag refers back into the previous tile
    local_unique_offset = cumsum - tl.where(global_pid > 0, 1, 0)
    local_unique_mask = (local_unique_offset >= 0) & mask
    if return_counts:
        # origin_idx: scatter_(to=cumsum, i0) -- first sorted position of
        # each run, compacted to the front of this tile's region
        origin_idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & local_unique_mask
        lu_store_offset = offset + local_unique_offset
        lu_store_offset = tl.where(origin_idx_mask, lu_store_offset, -1)
        tl.store(
            origin_idx_ptr + lu_store_offset,
            i0,
            mask=origin_idx_mask,
        )
    else:
        # local_unique: scatter_(to=cumsum, sorted_data) -- the unique
        # values, compacted per tile
        lu_store_offset = offset + local_unique_offset
        lu_store_offset = tl.where(local_unique_mask, lu_store_offset, -1)
        tl.store(local_unique_ptr + lu_store_offset, a, mask=local_unique_mask)

    # tile_sum: the last lane publishes this tile's unique count; tile 0
    # adds one because its first element is implicit (cumsum starts at 0)
    tile_sum_mask = (r == tile_size - 1) & (global_pid < global_ctas_num)
    tile_sum = tl.where(tile_sum_mask & (global_pid == 0), cumsum + 1, cumsum)
    tile_sum_store_offset = global_pid + tl.zeros_like(r)
    tile_sum_store_offset = tl.where(tile_sum_mask, tile_sum_store_offset, -1)
    tl.store(tile_sum_ptr + tile_sum_store_offset, tile_sum, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_quick_unique_flat_kernel(
    sorted_data_ptr: tl.tensor,  # in
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Grid-stride launch wrapper over local_quick_unique_flat_impl: each
    program handles tiles_per_cta tiles, stepping by the program count."""
    block_id = tle.program_id(0)
    stride = tle.num_programs(0)
    for step in range(tiles_per_cta):
        tile_id = block_id + step * stride
        local_quick_unique_flat_impl(
            tile_id,
            sorted_data_ptr,
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,
            global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
@triton.jit
def global_quick_unique_flat_impl(
    global_pid,
    total,
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: tl.constexpr,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: tl.constexpr,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Multi-tile unique, stage 2: move per-tile compacted entries to their
    global positions.

    `total` is the running number of uniques before this tile.  Each
    grid-stride iteration adds the tile sums of the window
    [global_pid - ctas_num, global_pid) -- the tiles covered since the
    caller's previous iteration -- so the prefix stays exact without
    re-summing everything.  The very last tile also overwrites its own
    tile_sum slot with the grand total, which the host later reads as the
    final unique count.  Returns the updated `total`.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks

    # load tile_sum: partial prefix over the window of ctas_num tiles
    # preceding this one (clamped to [0, global_ctas_num))
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)
    cur_tile_sum_mask = global_pid < global_ctas_num
    cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask)

    # total: exclusive prefix of unique counts before this tile
    total += tl.sum(pre_tile_sum)
    if global_pid == global_ctas_num - 1:
        # last tile publishes the grand total into its own tile_sum slot
        last_tile_sum_mask = p == global_pid
        tile_offset = tl.where(last_tile_sum_mask, p, -1)
        tl.store(
            tile_sum_ptr + tile_offset, total + cur_tile_sum, mask=last_tile_sum_mask
        )

    # idx or data_out: copy this tile's cur_tile_sum compacted entries to
    # their final slots starting at `total`
    tile_mask = r < cur_tile_sum
    out_offset = total + r
    if return_counts:
        # move origin_idx to idx_ptr
        origin_idx = tl.load(origin_idx_ptr + i0, mask=mask)
        idx_offset = tl.where(tile_mask, out_offset, -1)
        tl.store(idx_ptr + idx_offset, origin_idx, mask=tile_mask)
    else:
        # move local_unique to data_out_ptr
        local_unique = tl.load(local_unique_ptr + i0, mask=mask)
        data_out_offset = tl.where(tile_mask, out_offset, -1)
        tl.store(data_out_ptr + data_out_offset, local_unique, mask=tile_mask)

    return total
@libentry()
@triton.jit
def global_quick_unique_flat_kernel(
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: tl.constexpr,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: tl.constexpr,
    tiles_per_cta: tl.constexpr,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Dispatch wrapper for global_quick_unique_flat_impl.

    one_tile_per_cta chooses between a monolithic launch (one tile per
    program, running total starts at 0) and a grid-stride loop that threads
    the running `total` through successive impl calls.
    """
    pid = tle.program_id(0)
    # NOTE(review): this shadows the ctas_num *parameter* with the actual
    # launched program count; presumably the two agree at every call site --
    # confirm before relying on the parameter value.
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_quick_unique_flat_impl(
            pid,
            0,
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # in
            data_out_ptr,
            idx_ptr,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_quick_unique_flat_impl(
                global_pid,
                total,
                local_unique_ptr,
                origin_idx_ptr,
                tile_sum_ptr,  # in
                data_out_ptr,
                idx_ptr,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
@triton.jit
def global_quick_unique_flat_impl_stage_1(
    global_pid,
    total,
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: tl.constexpr,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: tl.constexpr,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Split variant of global_quick_unique_flat_impl, stage 1 of 2: only
    accumulates the running prefix `total` and has the last tile publish the
    grand total into its own tile_sum slot.  The data movement half is done
    by stage 2.  Returns the updated `total`.
    """
    # r = tl.arange(0, tile_size)
    # i0 = global_pid * tile_size + r
    # mask = i0 < num_tasks

    # load tile_sum: window of ctas_num tiles preceding this one
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)
    cur_tile_sum_mask = global_pid < global_ctas_num
    cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask)

    # total: exclusive prefix of unique counts before this tile
    total += tl.sum(pre_tile_sum)
    if global_pid == global_ctas_num - 1:
        # last tile publishes the grand total into its own tile_sum slot
        last_tile_sum_mask = p == global_pid
        tile_offset = tl.where(last_tile_sum_mask, p, -1)
        tl.store(
            tile_sum_ptr + tile_offset, total + cur_tile_sum, mask=last_tile_sum_mask
        )

    return total
@libentry()
@triton.jit
def global_quick_unique_flat_kernel_stage_1(
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: tl.constexpr,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: tl.constexpr,
    tiles_per_cta: tl.constexpr,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Dispatch wrapper for global_quick_unique_flat_impl_stage_1 (prefix
    accumulation only); mirrors global_quick_unique_flat_kernel's
    monolithic vs grid-stride dispatch.
    """
    pid = tle.program_id(0)
    # NOTE(review): shadows the ctas_num parameter with the launched
    # program count -- presumably identical at every call site; confirm.
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_quick_unique_flat_impl_stage_1(
            pid,
            0,
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # in
            data_out_ptr,
            idx_ptr,  # out
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_quick_unique_flat_impl_stage_1(
                global_pid,
                total,
                local_unique_ptr,
                origin_idx_ptr,
                tile_sum_ptr,  # in
                data_out_ptr,
                idx_ptr,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
@triton.jit
def global_quick_unique_flat_impl_stage_2(
    global_pid,
    total,
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    total_in_ptr,
    ctas_num: tl.constexpr,
    global_ctas_num: tl.constexpr,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: tl.constexpr,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Split variant of global_quick_unique_flat_impl, stage 2 of 2: data
    movement only.  The incoming `total` argument is ignored -- the
    exclusive prefix for this tile is read from total_in_ptr[global_pid],
    which the host precomputes (torch.cumsum + roll of the tile sums).
    Returns the loaded prefix.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks

    # load tile_sum
    # p = tl.arange(0, next_power_global_ctas_num)
    # pre_tile_sum_mask = (
    #     (p >= global_pid - ctas_num)
    #     & (p < global_pid)
    #     & (p >= 0)
    #     & (p < global_ctas_num)
    # )
    # pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)
    cur_tile_sum_mask = global_pid < global_ctas_num
    cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask)

    # total: host-computed exclusive prefix, replacing the on-device
    # window accumulation of the fused kernel
    total_in_mask = global_pid < global_ctas_num
    total = tl.load(total_in_ptr + global_pid, mask=total_in_mask)
    # tl.device_print("total", total)

    # idx or data_out: copy this tile's cur_tile_sum compacted entries to
    # their final slots starting at `total`
    tile_mask = r < cur_tile_sum
    out_offset = total + r
    if return_counts:
        # move origin_idx to idx_ptr
        origin_idx = tl.load(origin_idx_ptr + i0, mask=mask)
        idx_offset = tl.where(tile_mask, out_offset, -1)
        tl.store(idx_ptr + idx_offset, origin_idx, mask=tile_mask)
    else:
        # move local_unique to data_out_ptr
        local_unique = tl.load(local_unique_ptr + i0, mask=mask)
        data_out_offset = tl.where(tile_mask, out_offset, -1)
        tl.store(data_out_ptr + data_out_offset, local_unique, mask=tile_mask)

    return total
@libentry()
@triton.jit
def global_quick_unique_flat_kernel_stage_2(
    local_unique_ptr: tl.tensor,
    origin_idx_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    total_in_ptr,
    ctas_num: tl.constexpr,
    global_ctas_num: tl.constexpr,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: tl.constexpr,
    tiles_per_cta: tl.constexpr,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Dispatch wrapper for global_quick_unique_flat_impl_stage_2 (data
    movement using the host-computed prefix total_in_ptr); mirrors the
    monolithic vs grid-stride dispatch of the fused kernel.
    """
    pid = tle.program_id(0)
    # NOTE(review): shadows the ctas_num parameter with the launched
    # program count -- presumably identical at every call site; confirm.
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_quick_unique_flat_impl_stage_2(
            pid,
            0,
            local_unique_ptr,
            origin_idx_ptr,
            tile_sum_ptr,  # in
            data_out_ptr,
            idx_ptr,  # out
            total_in_ptr,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_quick_unique_flat_impl_stage_2(
                global_pid,
                total,
                local_unique_ptr,
                origin_idx_ptr,
                tile_sum_ptr,  # in
                data_out_ptr,
                idx_ptr,  # out
                total_in_ptr,
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool):
    """Unique on an already-sorted flat tensor (Kunlunxin XPU quick path).

    Returns (unique_values, None, counts_or_None); the middle slot is a
    placeholder for inverse indices, which this quick path never computes.

    Pipeline:
      1. local_quick_unique_flat_kernel -- per-tile compaction + tile sums
      2. global pass: a single fused kernel for inputs < 2**26 elements,
         otherwise a two-stage variant where the exclusive prefix of the
         tile sums (total_in) is computed on the host via cumsum/roll
      3. if return_counts: quick_output_flat_kernel gathers the unique
         values and differences idx into counts

    NOTE(review): the TRITONXPU_* environment variables appear to toggle
    Kunlunxin backend/simulator behavior around each launch, and leak if a
    launch raises (no try/finally) -- confirm against the vendor Triton
    toolchain docs.
    """
    num_tasks = sorted_data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    tile_size = min(8192, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)
    # if global_ctas_num <= 8192:
    #     tile_size = max(
    #         32, min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks)
    #     )
    #     global_ctas_num = triton.cdiv(num_tasks, tile_size)
    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    ctas_num = global_ctas_num  # if global_ctas_num < 65536 else 2048
    # with ctas_num == global_ctas_num this is 1; kept for the capped case
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)
    # print(f"ctas_num = {ctas_num}")
    # print(f"tile_size = {tile_size}")
    # print(f"global_ctas_num = {global_ctas_num}")
    # print(f"tiles_per_cta = {tiles_per_cta}")

    # allocate tensor: the counts path compacts run-start positions
    # (origin_idx/idx); the plain path compacts the values (local_unique)
    if return_counts:
        local_unique = None
        origin_idx = torch.empty_like(sorted_data, dtype=torch.int64)
        idx = torch.empty_like(origin_idx)
    else:
        local_unique = torch.empty_like(sorted_data)
        origin_idx = None
        idx = None
    counts = None
    tile_sum = torch.empty(
        (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
    )
    data_out = None
    if not return_counts:
        data_out = torch.empty_like(sorted_data)
        assert tiles_per_cta == 1
    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        local_quick_unique_flat_kernel[grid](
            sorted_data,  # in
            local_unique,
            origin_idx,
            tile_sum,  # out
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        if "TRITONXPU_OTHER_SIM" in os.environ:
            del os.environ["TRITONXPU_OTHER_SIM"]
        if "TRITONXPU_STORE_MASK_SIM" in os.environ:
            del os.environ["TRITONXPU_STORE_MASK_SIM"]

        if num_tasks < 2**26:
            # small/medium input: single fused global pass
            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            os.environ["TRITONXPU_INTERLEAVE"] = "0"
            global_quick_unique_flat_kernel[grid](
                local_unique,
                origin_idx,
                tile_sum,  # in
                data_out,
                idx,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tiles_per_cta=tiles_per_cta,
                tile_size=tile_size,
                one_tile_per_cta=tiles_per_cta == 1,
                return_counts=return_counts,
                num_warps=num_warps,
                isCloseVectorization=True,
                # buffer_size_limit=128,
            )
            if "TRITONXPU_OTHER_SIM" in os.environ:
                del os.environ["TRITONXPU_OTHER_SIM"]
            if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                del os.environ["TRITONXPU_STORE_MASK_SIM"]
            if "TRITONXPU_INTERLEAVE" in os.environ:
                del os.environ["TRITONXPU_INTERLEAVE"]
        else:
            # large input: compute the exclusive prefix of tile sums on the
            # host, then run the two split stages
            # print(f'tile_sum.shape = {tile_sum.shape}')
            # print(f'tile_sum.cpu() = {tile_sum.cpu()}')
            total_in = torch.cumsum(tile_sum, dim=0)
            total_in = torch.roll(total_in, shifts=1)
            total_in[0] = 0
            # print(f'in total_in.cpu() = {total_in.cpu()}')

            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            os.environ["TRITONXPU_INTERLEAVE"] = "0"
            global_quick_unique_flat_kernel_stage_1[grid](
                local_unique,
                origin_idx,
                tile_sum,  # in
                data_out,
                idx,  # out
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tiles_per_cta=tiles_per_cta,
                tile_size=tile_size,
                one_tile_per_cta=tiles_per_cta == 1,
                return_counts=return_counts,
                num_warps=num_warps,
                isCloseVectorization=True,
                buffer_size_limit=128,
            )
            if "TRITONXPU_OTHER_SIM" in os.environ:
                del os.environ["TRITONXPU_OTHER_SIM"]
            if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                del os.environ["TRITONXPU_STORE_MASK_SIM"]
            if "TRITONXPU_INTERLEAVE" in os.environ:
                del os.environ["TRITONXPU_INTERLEAVE"]

            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            os.environ["TRITONXPU_INTERLEAVE"] = "0"
            global_quick_unique_flat_kernel_stage_2[grid](
                local_unique,
                origin_idx,
                tile_sum,  # in
                data_out,
                idx,  # out
                total_in,
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tiles_per_cta=tiles_per_cta,
                tile_size=tile_size,
                one_tile_per_cta=tiles_per_cta == 1,
                return_counts=return_counts,
                num_warps=num_warps,
                isCloseVectorization=True,
            )
            if "TRITONXPU_OTHER_SIM" in os.environ:
                del os.environ["TRITONXPU_OTHER_SIM"]
            if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                del os.environ["TRITONXPU_STORE_MASK_SIM"]
            if "TRITONXPU_INTERLEAVE" in os.environ:
                del os.environ["TRITONXPU_INTERLEAVE"]

        # the global pass left the grand total in the last tile_sum slot;
        # .item() synchronizes with the device
        out_size = tile_sum[-1].item()
        if return_counts:
            data_out = torch.empty(
                (out_size,), dtype=sorted_data.dtype, device=sorted_data.device
            )
            idx = idx[:out_size]
            counts = origin_idx[:out_size]  # reuse origin_idx storage for counts
            quick_output_flat_kernel[grid](
                sorted_data,
                idx,
                num_tasks,  # in
                data_out,
                counts,  # out
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
                isCloseUnrollControl=True
                if sorted_data.dtype == torch.int16
                else False,
            )

    if return_counts:
        return data_out, None, counts
    else:
        return data_out[:out_size], None, None
@triton.jit
def local_ne_flat_impl(
    global_pid,
    sorted_data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Tile body for the general (inverse-capable) path: write the
    "differs from predecessor" flag of every element and this tile's flag
    sum (used later as a per-tile prefix contribution).
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)

    # load element and predecessor (global element 0 compares with itself)
    a = tl.load(sorted_data_ptr + i0, mask=mask)
    b = tl.load(sorted_data_ptr + i0_prev, mask=mask)

    # compute: 1 where the element starts a new run, 0 otherwise
    ne_result = tl.where(i0 > 0, a != b, 0)

    # store ne_result
    tl.store(ne_result_ptr + i0, ne_result, mask=mask)

    # store tile_sum: number of run starts inside this tile
    tile_sum = tl.sum(ne_result)
    tile_sum_mask = global_pid < global_ctas_num
    tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_ne_flat_kernel(
    sorted_data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Grid-stride launch wrapper: each program walks tiles_per_cta tiles,
    stepping by the program count, delegating each tile to
    local_ne_flat_impl."""
    block_id = tle.program_id(0)
    stride = tle.num_programs(0)
    for step in range(tiles_per_cta):
        tile_id = block_id + step * stride
        local_ne_flat_impl(
            tile_id,
            sorted_data_ptr,
            ne_result_ptr,
            tile_sum_ptr,
            global_ctas_num,
            num_tasks,
            tile_size,
        )
@triton.jit
def global_cumsum_flat_impl(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    cumsum_out,
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """General path, global pass: turn per-tile ne flags into global slots
    and scatter all outputs.

    `total` carries the running flag prefix before this tile; each
    grid-stride iteration adds the tile sums of the window
    [global_pid - ctas_num, global_pid).  Outputs: unique values
    (data_out), inverse indices (original position -> slot), and, if
    return_counts, idx (slot -> first sorted position).  The element owning
    position num_tasks - 1 stores the final slot index into
    tile_sum_ptr[global_pid] for the host to read.  Returns the updated
    `total`.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load sorted_data, sorted_indices
    sorted_data = tl.load(sorted_data_ptr + i0, mask=mask)
    sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)

    # load tile_sum: window of ctas_num tiles preceding this one
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)

    # cumsum: per-tile inclusive cumsum of the precomputed ne flags
    total += tl.sum(pre_tile_sum)
    # tl.device_print("total", total)
    ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    ne_result_i1 = ne_result.to(tl.int1)
    ne_result = ne_result.to(tl.int32)
    # tl.device_print("ne_result", ne_result)
    cumsum = tl.cumsum(ne_result)
    # tl.store(cumsum_out + i0, cumsum)
    # tl.device_print("cumsum", cumsum)

    # tile_sum: the lane holding the last input element publishes the final
    # slot index (total + cumsum) into this tile's tile_sum slot
    if global_pid == global_ctas_num - 1:
        last_tile_sum_mask = i0 == num_tasks - 1
        tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum)
        tile_offset = tl.where(last_tile_sum_mask, global_pid + tl.zeros_like(r), -1)
        tl.store(
            tile_sum_ptr + tile_offset,
            tile_sum,
            mask=last_tile_sum_mask,
        )
    # shift the tile-local cumsum into the global slot space
    cumsum += total

    # data_out: scatter_(to=cumsum, sorted_data)
    tl.store(data_out_ptr + cumsum, sorted_data, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum)
    tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: first sorted position of each run (run starts only)
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result_i1) & mask
        idx_offset = tl.where(idx_mask, cumsum, num_tasks + 1)
        tl.store(idx_ptr + idx_offset, i0, mask=idx_mask)

    return total
@libentry()
@triton.jit
def global_cumsum_flat_kernel(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    cumsum_out,
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Dispatch wrapper for global_cumsum_flat_impl: monolithic launch
    (running total starts at 0) or a grid-stride loop that threads `total`
    through successive impl calls.
    """
    pid = tle.program_id(0)
    # NOTE(review): shadows the ctas_num parameter with the launched
    # program count -- presumably identical at every call site; confirm.
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_cumsum_flat_impl(
            pid,
            0,
            ne_result_ptr,
            tile_sum_ptr,  # in
            sorted_data_ptr,
            sorted_indices_ptr,  # in
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,  # out
            cumsum_out,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_cumsum_flat_impl(
                global_pid,
                total,
                ne_result_ptr,
                tile_sum_ptr,  # in
                sorted_data_ptr,
                sorted_indices_ptr,  # in
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,  # out
                cumsum_out,
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
@triton.jit
def global_cumsum_flat_impl_stage_1(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    total_in_ptr,
    cumsum_in_ptr,
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Split variant of global_cumsum_flat_impl, stage 1 of 2: only has the
    last tile publish the final slot index into its tile_sum slot.  The
    prefix `total` and the per-tile cumsum are precomputed on the host and
    read from total_in_ptr / cumsum_in_ptr.  Returns the loaded prefix.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load sorted_data, sorted_indices
    # sorted_data = tl.load(sorted_data_ptr + i0, mask=mask)
    # sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)

    # load tile_sum
    # p = tl.arange(0, next_power_global_ctas_num)
    # pre_tile_sum_mask = (
    #     (p >= global_pid - ctas_num)
    #     & (p < global_pid)
    #     & (p >= 0)
    #     & (p < global_ctas_num)
    # )
    # pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)

    # cumsum
    # total += tl.sum(pre_tile_sum)
    # ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    # ne_result_i1 = ne_result.to(tl.int1)
    # ne_result = ne_result.to(tl.int32)
    # cumsum = tl.cumsum(ne_result)
    # total: host-computed exclusive prefix for this tile
    total_in_mask = global_pid < global_ctas_num
    total = tl.load(total_in_ptr + global_pid, mask=total_in_mask)

    ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    # ne_result_i1 = ne_result.to(tl.int1)
    ne_result = ne_result.to(tl.int32)
    # tl.device_print("ne_result", ne_result)
    # cumsum = tl.cumsum(ne_result)
    # cumsum: host-precomputed per-tile inclusive cumsum of the ne flags
    cumsum = tl.load(cumsum_in_ptr + i0)

    # tile_sum: the lane holding the last input element publishes the final
    # slot index (total + cumsum) into this tile's tile_sum slot
    if global_pid == global_ctas_num - 1:
        last_tile_sum_mask = i0 == num_tasks - 1
        tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum)
        tile_offset = tl.where(last_tile_sum_mask, global_pid + tl.zeros_like(r), -1)
        tl.store(
            tile_sum_ptr + tile_offset,
            tile_sum,
            mask=last_tile_sum_mask,
        )

    return total
@libentry()
@triton.jit
def global_cumsum_flat_kernel_stage_1(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    total_in_ptr,
    cumsum_in_ptr,
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Dispatch wrapper for global_cumsum_flat_impl_stage_1 (publishes the
    final slot index only, using host-computed total_in/cumsum_in); mirrors
    the monolithic vs grid-stride dispatch of the fused kernel.
    """
    pid = tle.program_id(0)
    # NOTE(review): shadows the ctas_num parameter with the launched
    # program count -- presumably identical at every call site; confirm.
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolitic kernel style
        global_cumsum_flat_impl_stage_1(
            pid,
            0,
            ne_result_ptr,
            tile_sum_ptr,  # in
            sorted_data_ptr,
            sorted_indices_ptr,  # in
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,  # out
            total_in_ptr,
            cumsum_in_ptr,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_cumsum_flat_impl_stage_1(
                global_pid,
                total,
                ne_result_ptr,
                tile_sum_ptr,  # in
                sorted_data_ptr,
                sorted_indices_ptr,  # in
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,  # out
                total_in_ptr,
                cumsum_in_ptr,
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
@triton.jit
def global_cumsum_flat_impl_stage_2(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    total_in_ptr,
    cumsum_in_ptr,
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Split variant of global_cumsum_flat_impl, stage 2 of 2: scatter
    outputs using the host-computed exclusive prefix (total_in_ptr) and
    per-tile cumsum (cumsum_in_ptr).  The incoming `total` argument is
    ignored -- it is overwritten by the load.  Returns the loaded prefix.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load sorted_data, sorted_indices
    sorted_data = tl.load(sorted_data_ptr + i0, mask=mask)
    sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)

    # load tile_sum
    # p = tl.arange(0, next_power_global_ctas_num)
    # pre_tile_sum_mask = (
    #     (p >= global_pid - ctas_num)
    #     & (p < global_pid)
    #     & (p >= 0)
    #     & (p < global_ctas_num)
    # )
    # pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)

    # cumsum: host-computed exclusive prefix for this tile
    total_in_mask = global_pid < global_ctas_num
    total = tl.load(total_in_ptr + global_pid, mask=total_in_mask)

    ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    ne_result_i1 = ne_result.to(tl.int1)
    ne_result = ne_result.to(tl.int32)
    # tl.device_print("ne_result", ne_result)
    # cumsum = tl.cumsum(ne_result)
    # host-precomputed per-tile inclusive cumsum, shifted to global slots
    cumsum = tl.load(cumsum_in_ptr + i0)
    # tl.device_print("cumsum", cumsum)
    cumsum += total

    # data_out: scatter_(to=cumsum, sorted_data)
    tl.store(data_out_ptr + cumsum, sorted_data, mask=mask)

    # inverse_indices: scatter_(to=sorted_indices, cumsum)
    tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)

    # idx: first sorted position of each run (run starts only)
    if return_counts:
        idx_mask = ((i0 == 0) | ne_result_i1) & mask
        idx_offset = tl.where(idx_mask, cumsum, num_tasks + 1)
        tl.store(idx_ptr + idx_offset, i0, mask=idx_mask)

    return total
@libentry()
@triton.jit
def global_cumsum_flat_kernel_stage_2(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    total_in_ptr,
    cumsum_in_ptr,
    ctas_num: int,  # NOTE: overwritten below with tle.num_programs(0)
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Driver for stage 2: run the per-tile scatter impl for each owned tile.

    With ``one_tile_per_cta`` each program handles exactly one tile; otherwise
    programs walk their tiles in a grid-stride loop, ``tiles_per_cta`` tiles
    each.  All real work happens in ``global_cumsum_flat_impl_stage_2``.
    """
    pid = tle.program_id(0)
    # trust the actual launch grid rather than the passed-in value
    ctas_num = tle.num_programs(0)
    if one_tile_per_cta:  # monolithic kernel style
        global_cumsum_flat_impl_stage_2(
            pid,
            0,  # initial running total; the impl reloads it from total_in_ptr
            ne_result_ptr,
            tile_sum_ptr,  # in
            sorted_data_ptr,
            sorted_indices_ptr,  # in
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,  # out
            total_in_ptr,
            cumsum_in_ptr,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_counts,
        )
    else:  # grid-stride-loop style kernel
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_cumsum_flat_impl_stage_2(
                global_pid,
                total,
                ne_result_ptr,
                tile_sum_ptr,  # in
                sorted_data_ptr,
                sorted_indices_ptr,  # in
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,  # out
                total_in_ptr,
                cumsum_in_ptr,
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_counts,
            )
def sorted_indices_unique_flat(
    sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool
):
    """Unique over a pre-sorted flat tensor, also producing inverse indices.

    Args:
        sorted_data: ascending values (first output of ``torch.sort`` on the
            raveled input).
        sorted_indices: permutation mapping sorted positions back to original
            positions (second output of ``torch.sort``).
        return_counts: when True, also compute per-unique-value counts.

    Returns:
        ``(unique_values, inverse_indices, counts)``; ``counts`` is ``None``
        unless ``return_counts``.
    """

    def _set_sim_env():
        # Triton-XPU backend switches, read at kernel compile/launch time.
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        os.environ["TRITONXPU_INTERLEAVE"] = "0"

    def _clear_sim_env():
        # pop() tolerates missing keys, so cleanup is always safe to run.
        os.environ.pop("TRITONXPU_OTHER_SIM", None)
        os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
        os.environ.pop("TRITONXPU_INTERLEAVE", None)

    num_tasks = sorted_data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    tile_size = min(2048, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)
    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    ctas_num = global_ctas_num
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)

    # allocate tensors
    ne_result = torch.empty_like(sorted_data, dtype=torch.bool)
    tile_sum = torch.empty(
        (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
    )
    data_out = torch.empty_like(sorted_data)
    inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
    idx = torch.empty_like(inverse_indices) if return_counts else None

    # launch kernels; the env flags are cleared in finally blocks so a
    # failing launch can no longer leak them into the process environment
    with torch_device_fn.device(sorted_data.device.index):
        # Pass 1: per-element "differs from predecessor" flags plus the
        # per-tile sum of those flags.
        _set_sim_env()
        try:
            local_ne_flat_kernel[grid](
                sorted_data,  # in
                ne_result,
                tile_sum,  # out
                global_ctas_num,
                num_tasks,
                tiles_per_cta=tiles_per_cta,
                tile_size=tile_size,
                num_warps=num_warps,
            )
        finally:
            _clear_sim_env()

        if num_tasks < 2**26:
            # Small inputs: one fused kernel performs the global cumsum and
            # all of the scatters.
            next_multiple = ((num_tasks // 2048) + 1) * 2048
            # NOTE(review): float32 scratch on the default device, while the
            # large-input branch feeds an int64 device tensor into the same
            # kernel slot -- confirm this is what the fused kernel expects.
            cumsum_out = torch.zeros(next_multiple)
            _set_sim_env()
            try:
                global_cumsum_flat_kernel[grid](
                    ne_result,
                    tile_sum,  # in
                    sorted_data,
                    sorted_indices,  # in
                    data_out,
                    inverse_indices,
                    idx,  # out
                    cumsum_out,
                    ctas_num,
                    global_ctas_num,
                    next_power_global_ctas_num,
                    num_tasks,
                    tiles_per_cta=tiles_per_cta,
                    tile_size=tile_size,
                    one_tile_per_cta=tiles_per_cta == 1,
                    return_counts=return_counts,
                    num_warps=num_warps,
                )
            finally:
                _clear_sim_env()
        else:
            # Large inputs: precompute on the host
            #   * total_in -- exclusive prefix sums of the per-tile totals,
            #   * cumsum_result -- 2048-wide block-local inclusive cumsum of
            #     the ne flags (padded so the reshape is exact),
            # then run the two-stage device kernels.
            total_in = torch.cumsum(tile_sum, dim=0)
            total_in = torch.roll(total_in, shifts=1)
            total_in[0] = 0
            next_multiple = ((num_tasks // 2048) + 1) * 2048
            padding_size = next_multiple - num_tasks
            padded_ne_result = torch.nn.functional.pad(
                ne_result, (0, padding_size), "constant", 0
            )
            num_blocks = next_multiple // 2048
            reshaped = padded_ne_result.view(num_blocks, 2048)
            cumsum_blocks = torch.cumsum(reshaped, dim=1)
            cumsum_result = cumsum_blocks.view(-1)

            _set_sim_env()
            try:
                global_cumsum_flat_kernel_stage_1[grid](
                    ne_result,
                    tile_sum,  # in
                    sorted_data,
                    sorted_indices,  # in
                    data_out,
                    inverse_indices,
                    idx,  # out
                    total_in,
                    cumsum_result,
                    ctas_num,
                    global_ctas_num,
                    next_power_global_ctas_num,
                    num_tasks,
                    tiles_per_cta=tiles_per_cta,
                    tile_size=tile_size,
                    one_tile_per_cta=tiles_per_cta == 1,
                    return_counts=return_counts,
                    num_warps=num_warps,
                )
            finally:
                _clear_sim_env()

            _set_sim_env()
            try:
                global_cumsum_flat_kernel_stage_2[grid](
                    ne_result,
                    tile_sum,  # in
                    sorted_data,
                    sorted_indices,  # in
                    data_out,
                    inverse_indices,
                    idx,  # out
                    total_in,
                    cumsum_result,
                    ctas_num,
                    global_ctas_num,
                    next_power_global_ctas_num,
                    num_tasks,
                    tiles_per_cta=tiles_per_cta,
                    tile_size=tile_size,
                    one_tile_per_cta=tiles_per_cta == 1,
                    return_counts=return_counts,
                    num_warps=num_warps,
                    isCloseUnrollControl=True,
                )
            finally:
                _clear_sim_env()

    # NOTE(review): relies on the kernels leaving the final global running
    # total in tile_sum[-1]; that store happens in device code not visible
    # here -- confirm before changing.
    out_size = tile_sum[-1].item() + 1
    counts = None
    if return_counts:
        idx = idx[:out_size]
        counts = torch.empty_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            _set_sim_env()
            try:
                output_counts_flat_kernel[grid](
                    idx,
                    num_tasks,  # in
                    counts,  # out
                    out_size,
                    tiles_per_cta,
                    tile_size,
                    num_warps=num_warps,
                )
            finally:
                _clear_sim_env()
    return data_out[:out_size], inverse_indices, counts
def simple_unique_flat(
    sorted_data: torch.Tensor,
    sorted_indices: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Single-CTA unique for small flat inputs (whole problem in one tile).

    Args:
        sorted_data: ascending values (first output of ``torch.sort``).
        sorted_indices: permutation back to original positions (second output
            of ``torch.sort``).
        return_inverse: when True, compute original-position -> unique-slot
            indices.
        return_counts: when True, compute per-unique-value counts.

    Returns:
        ``(unique_values, inverse_indices_or_None, counts_or_None)``.
    """

    def _set_sim_env():
        # Triton-XPU backend switches, read at kernel compile/launch time.
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        os.environ["TRITONXPU_INTERLEAVE"] = "0"

    def _clear_sim_env():
        # pop() tolerates missing keys; cleanup also runs if a launch raises,
        # so the flags can no longer leak into the process environment.
        os.environ.pop("TRITONXPU_OTHER_SIM", None)
        os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
        os.environ.pop("TRITONXPU_INTERLEAVE", None)

    num_tasks = sorted_data.numel()
    grid = (1, 1, 1)

    # allocate outputs
    data_out = torch.zeros_like(sorted_data)
    inverse_indices = (
        torch.zeros_like(sorted_data, dtype=torch.int64) if return_inverse else None
    )
    idx = torch.zeros_like(sorted_data, dtype=torch.int64) if return_counts else None
    unique_size = torch.zeros([1], dtype=torch.int64, device=sorted_data.device)

    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        _set_sim_env()
        try:
            simple_unique_flat_kernel[grid](
                sorted_data,
                sorted_indices,  # in
                data_out,
                inverse_indices,
                idx,
                unique_size,  # out
                return_inverse,
                return_counts,
                num_tasks,
                tile_size=triton.next_power_of_2(num_tasks),
                num_warps=8,
            )
        finally:
            _clear_sim_env()

    # the kernel stores the last running group id; +1 turns it into a count
    out_size = unique_size.item() + 1
    counts = None
    if return_counts:
        idx = idx[:out_size]
        counts = torch.empty_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            _set_sim_env()
            try:
                output_counts_flat_kernel[grid](
                    idx,
                    num_tasks,  # in
                    counts,  # out
                    num_tasks=out_size,
                    tiles_per_cta=1,
                    tile_size=triton.next_power_of_2(out_size),
                    num_warps=8,
                )
            finally:
                _clear_sim_env()
    return data_out[:out_size], inverse_indices, counts
def _unique2(
    in0: torch.Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """Entry point mirroring ATen ``_unique2``: dispatch to a flat-unique impl.

    All paths sort the raveled input first, then pick an implementation:
    a single-CTA kernel for small inputs, the inverse-producing multi-tile
    path when inverse indices are requested, and the quick path otherwise.
    Returns ``(unique_values, inverse_indices_or_None, counts_or_None)`` with
    the inverse reshaped to the input's shape.
    """
    flat = in0.ravel()
    sorted_data, sorted_indices = torch.sort(flat)
    if flat.numel() <= 8192:
        unique_vals, inverse, counts = simple_unique_flat(
            sorted_data, sorted_indices, return_inverse, return_counts
        )
    elif return_inverse:
        unique_vals, inverse, counts = sorted_indices_unique_flat(
            sorted_data, sorted_indices, return_counts
        )
    else:
        unique_vals, inverse, counts = sorted_quick_unique_flat(
            sorted_data, return_counts
        )
    if inverse is not None:
        inverse = inverse.view_as(in0)
    return (unique_vals, inverse, counts)