Coverage for src/flag_gems/runtime/backend/_cambricon/utils/reduce_utils.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import math 

2 

3import triton 

4 

5from flag_gems import runtime 

6 

7from . import MAX_NRAM_SIZE, TOTAL_CORE_NUM 

8 

9 

def cfggen_reduce_op():
    """Return the tuned configs registered under ``"common_reduce_ops"``.

    The lookup is delegated entirely to ``runtime.get_tuned_config``; whatever
    that call returns is handed back unchanged.
    """
    tuned_configs = runtime.get_tuned_config("common_reduce_ops")
    return tuned_configs

12 

13 

def cfggen_reduce_op2():
    """Build the list of candidate ``triton.Config`` objects for reduce ops.

    One config is produced for every (block size, stage count) pairing, with
    the block size varying slowest. Each config carries ``BLOCK_SIZE`` and
    ``ITER_NUM = log2(BLOCK_SIZE) + 1`` and always uses ``num_warps=1``.

    Note: ``math.log2`` returns a float, so ``ITER_NUM`` is a float value
    (e.g. ``12.0`` for a block size of 2048).
    """
    candidates = []
    for size in (2048, 4096, 8192, 16384, 32768):
        for stages in (1, 3):
            meta = {"BLOCK_SIZE": size, "ITER_NUM": math.log2(size) + 1}
            candidates.append(
                triton.Config(meta, num_warps=1, num_stages=stages)
            )
    return candidates

25 

26 

def count_divisible_by_2(x):
    """Return how many times ``x`` can be halved while staying even.

    For a positive integer this is the exponent of the largest power of two
    dividing ``x``. Zero, negative and odd inputs all yield ``0``.
    """
    halvings = 0
    remaining = x
    while True:
        # Stop as soon as the value is non-positive or no longer even.
        if remaining <= 0 or remaining % 2 != 0:
            return halvings
        remaining //= 2
        halvings += 1

33 

34 

def next_power_of_two(x):
    """Round ``x`` up to a power of two, never returning less than 16.

    Any input below 16 (including zero and negatives) yields 16. A value that
    is already a power of two is returned as-is; anything else is rounded up
    to the next power of two.
    """
    floor = 16
    if x < floor:
        return floor
    # A power of two has a single set bit, so clearing the lowest set bit
    # (x & (x - 1)) leaves zero.
    already_power_of_two = (x & (x - 1)) == 0
    return x if already_power_of_two else 1 << (x - 1).bit_length()

41 

42 

def prune_reduce_config(configs, named_args, **kwargs):
    """Filter ``configs`` down to those usable for a reduction of size ``M``.

    Args:
        configs: Iterable of config objects exposing ``kwargs["BLOCK_SIZE"]``
            and ``num_stages``.
        named_args: Mapping that must contain the reduction length under
            the key ``"M"``.
        **kwargs: Accepted and ignored.

    Returns:
        The list of configs that pass the checks below. If none pass, a list
        holding a single default ``triton.Config`` whose ``BLOCK_SIZE`` is
        ``next_power_of_two(M)``, with ``num_warps=1`` and ``num_stages=1``.

    NOTE(review): the signature looks like a Triton autotuner
    ``early_config_prune`` hook; confirm against the call sites.
    """
    total = named_args["M"]

    def is_usable(cfg):
        block = cfg.kwargs["BLOCK_SIZE"]
        stages = cfg.num_stages
        blocks = total // block
        # Reject block sizes larger than the whole input.
        if blocks < 1:
            return False
        if blocks < TOTAL_CORE_NUM:
            # A core must process a BLOCK_SIZE of data, so multi-stage
            # configs are rejected here.
            if stages > 1:
                return False
            # The final IR will only have two allocs of BLOCK_SIZE:
            # - one for the pad generated by the mask load;
            # - one for the dst of computation.
            buffers = 2
        else:
            # A core may process more than one BLOCK_SIZE of data.
            # The final IR will only have four allocs of BLOCK_SIZE:
            # - one for the _tmp to receive the value;
            # - one for the pad generated by the mask load;
            # - one for the dst of computation;
            # - one for the return value of the for loop.
            buffers = 4
        # f32 (4 bytes per element) is assumed as the default type.
        return block * 4 * buffers <= MAX_NRAM_SIZE

    survivors = [cfg for cfg in configs if is_usable(cfg)]
    if survivors:
        return survivors
    # Nothing survived pruning (the original note attributes this to
    # M < 1024): fall back to a single default config.
    fallback = triton.Config(
        {"BLOCK_SIZE": next_power_of_two(total)}, num_warps=1, num_stages=1
    )
    return [fallback]