Coverage for src/flag_gems/runtime/backend/_cambricon/utils/reduce_utils.py: 0%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import math
3import triton
5from flag_gems import runtime
7from . import MAX_NRAM_SIZE, TOTAL_CORE_NUM
def cfggen_reduce_op():
    """Return the tuned configs registered under ``"common_reduce_ops"``.

    The lookup is delegated to ``runtime.get_tuned_config``; its result is
    handed back to the caller unchanged.
    """
    tuned_configs = runtime.get_tuned_config("common_reduce_ops")
    return tuned_configs
def cfggen_reduce_op2():
    """Build the hard-coded candidate ``triton.Config`` list.

    One config is produced for every (BLOCK_SIZE, num_stages) pair, iterating
    block sizes in the outer loop and stage counts in the inner loop. Every
    config uses ``num_warps=1``.

    Returns:
        A list of ``triton.Config`` objects whose kwargs carry ``BLOCK_SIZE``
        and ``ITER_NUM``.
    """
    stage_options = (1, 3)
    candidates = []
    for block in (2048, 4096, 8192, 16384, 32768):
        # math.log2 returns a float, so ITER_NUM is a float as well
        # (e.g. 12.0 for a block of 2048).
        iter_num = math.log2(block) + 1
        for stages in stage_options:
            # A fresh kwargs dict per config so no two configs share state.
            meta = {"BLOCK_SIZE": block, "ITER_NUM": iter_num}
            candidates.append(
                triton.Config(meta, num_warps=1, num_stages=stages)
            )
    return candidates
def count_divisible_by_2(x):
    """Return how many times ``x`` can be halved exactly.

    The value is repeatedly floor-divided by 2 for as long as it is positive
    and even. Zero, negative and odd inputs therefore yield 0.
    """
    halvings = 0
    remaining = x
    while remaining > 0:
        quotient, remainder = divmod(remaining, 2)
        if remainder != 0:
            # Odd value reached: no further exact halving is possible.
            break
        remaining = quotient
        halvings += 1
    return halvings
def next_power_of_two(x):
    """Return the smallest power of two that is >= ``x``, but never below 16.

    Any input smaller than 16 (including zero and negatives) maps to 16.
    An input that already is a power of two is returned as-is.
    """
    floor_value = 16
    if x < floor_value:
        return floor_value
    # A power of two has a single set bit, so clearing the lowest set bit
    # leaves zero.
    already_power = (x & (x - 1)) == 0
    return x if already_power else 1 << (x - 1).bit_length()
def prune_reduce_config(configs, named_args, **kwargs):
    """Filter reduce-kernel configs down to those usable for ``named_args["M"]``.

    A config is kept when all of the following hold:
      * ``M`` covers at least one full ``BLOCK_SIZE``;
      * if there are fewer blocks than ``TOTAL_CORE_NUM``, the config does
        not use more than one stage;
      * the estimated buffer footprint fits in ``MAX_NRAM_SIZE``.

    Args:
        configs: Iterable of config objects exposing ``kwargs["BLOCK_SIZE"]``
            and ``num_stages``.
        named_args: Mapping that must contain the key ``"M"``.
        **kwargs: Accepted and ignored.

    Returns:
        The list of surviving configs, or a single default config sized by
        ``next_power_of_two(M)`` when nothing survives.
    """
    total = named_args["M"]
    # Set f32 as the default type: four bytes per element.
    bytes_per_element = 4
    survivors = []
    for candidate in configs:
        block = candidate.kwargs["BLOCK_SIZE"]
        stages = candidate.num_stages
        block_count = total // block
        if block_count < 1:
            # M does not even fill one block of this size.
            continue
        # With fewer blocks than cores, a core must process a BLOCK_SIZE
        # of data; otherwise a core may process more than one block.
        one_block_per_core = block_count < TOTAL_CORE_NUM
        if one_block_per_core and stages > 1:
            continue
        if one_block_per_core:
            # The final IR will only have two allocs of BLOCK_SIZE:
            # - one for the pad generated by the mask load;
            # - one for the dst of computation.
            buffer_count = 2
        else:
            # The final IR will only have four allocs of BLOCK_SIZE:
            # - one for the _tmp to receive the value;
            # - one for the pad generated by the mask load;
            # - one for the dst of computation;
            # - one for the return value of the for loop.
            buffer_count = 4
        if block * bytes_per_element * buffer_count <= MAX_NRAM_SIZE:
            survivors.append(candidate)
    if survivors:
        return survivors
    # Nothing survived (e.g. M is smaller than every candidate block):
    # fall back to one default config.
    return [
        triton.Config(
            {"BLOCK_SIZE": next_power_of_two(total)}, num_warps=1, num_stages=1
        )
    ]