Coverage for src/flag_gems/runtime/backend/_cambricon/ops/masked_select.py: 0%

73 statements  

coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable 

10 

11from ..utils import TOTAL_CORE_NUM 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

@triton.autotune(configs=runtime.get_tuned_config("masked_select"), key=["n_elements"])
@triton.jit
def masked_select_kernel(
    inp_ptr,  # *T: flattened (broadcast, contiguous) input values
    select_mask_ptr,  # *mask: flattened selection mask; loaded and cast to int1
    select_val_ptr,  # *T out: per-tile compacted values, written back at the tile's own offsets
    select_num_ptr,  # *int32 out: one selected-element count per BLOCK_SIZE tile
    n_elements,  # total number of elements in the flattened input
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of masked_select: compact each BLOCK_SIZE tile locally.

    For every tile, the selected values are packed to the front of the tile
    and stored at the tile's original offsets in `select_val_ptr`, while the
    number of selected elements in that tile is recorded in `select_num_ptr`.
    A second pass (get_out_kernel) turns those per-tile counts into global
    output positions.
    """
    pid = tl.program_id(axis=0)
    num_p = tl.num_programs(axis=0)
    # Total number of BLOCK_SIZE tiles covering the input.
    split_n = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
    # Each program strides over tiles: it owns tiles pid, pid+num_p, ...
    step = BLOCK_SIZE * num_p
    offset_start = pid * BLOCK_SIZE
    loop = 0
    for offset in tl.range(offset_start, n_elements, step):
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        inp = tl.load(inp_ptr + offsets, mask=mask, other=0.0)
        select_mask = tl.load(select_mask_ptr + offsets, mask=mask, other=0.0).to(
            tl.int1
        )
        # NOTE(review): tl.masked_select is a Cambricon Triton extension;
        # presumably it returns the values of `inp` where `select_mask` is
        # true, packed to the front, plus the count selected — confirm
        # against the backend's Triton docs.
        select_val, select_num = tl.masked_select(inp, select_mask)
        tl.store(select_val_ptr + offsets, select_val, mask=mask)
        # Tiles are indexed as loop*num_p + pid; get_out_kernel reads the
        # counts back with the same indexing scheme.
        num_select_offset = loop * num_p + pid + tl.arange(0, 1)
        loop += 1
        num_select_mask = num_select_offset < split_n
        tl.store(select_num_ptr + num_select_offset, select_num, mask=num_select_mask)

45 

46 

@triton.jit
def get_out_kernel(
    select_val_ptr,  # *T: per-tile compacted values produced by masked_select_kernel
    select_num_ptr,  # *int32: per-tile selected-element counts from pass one
    output_ptr,  # *T out: densely packed final result
    n_elements: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # must equal the BLOCK_SIZE pass one was tuned to
):
    """Second pass of masked_select: scatter tiles to their global positions.

    Every program loads the full per-tile count array and takes its inclusive
    prefix sum; the sum over all tiles preceding tile i (prefix[i-1]) is the
    output base offset at which tile i's compacted values are written.
    """
    pid = tl.program_id(axis=0)
    num_p = tl.num_programs(axis=0)
    step = BLOCK_SIZE * num_p
    split_n: tl.constexpr = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE

    # Load all per-tile counts and build an inclusive prefix sum.
    all_select_num_offset = tl.arange(0, split_n)
    # Mask is always true (arange(0, split_n) < split_n); kept as written.
    all_select_num_mask = all_select_num_offset < split_n
    all_select_num = tl.load(
        select_num_ptr + all_select_num_offset, mask=all_select_num_mask, other=0.0
    )
    prefix_select_num = tl.cumsum(all_select_num, 0)

    offset_start = pid * BLOCK_SIZE
    loop = 0
    for offset in tl.range(offset_start, n_elements, step):
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        select_val = tl.load(select_val_ptr + offsets, mask=mask, other=0.0)
        # Tile index in the same loop*num_p + pid order pass one stored with.
        select_num_offset = loop * num_p + pid + tl.arange(0, 1)
        select_num_mask = select_num_offset < split_n
        select_num = tl.load(
            select_num_ptr + select_num_offset, mask=select_num_mask, other=0.0
        )
        if loop == 0 and pid == 0:
            # The very first tile writes starting at output position 0.
            output_offset = tl.arange(0, BLOCK_SIZE)
        else:
            # Later tiles start right after everything selected by all
            # earlier tiles (inclusive prefix sum up to the previous tile).
            output_offset = prefix_select_num[loop * num_p + pid - 1] + tl.arange(
                0, BLOCK_SIZE
            )
        loop += 1
        # Only the first select_num lanes of this tile hold real values.
        output_mask = tl.arange(0, BLOCK_SIZE) < select_num
        tl.store(output_ptr + output_offset, select_val, mask=output_mask)

87 

88 

def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of `inp` where `mask` is true.

    Two-pass device implementation: the first kernel compacts values tile by
    tile and records per-tile counts; the second scatters each tile's values
    to its final slot using a prefix sum of those counts.
    """
    logger.debug("GEMS_CAMBRICON MASKED SELECT")

    assert broadcastable(
        tuple(inp.shape), tuple(mask.shape)
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    # Both kernels index flat memory, so operands must be contiguous.
    inp = inp.contiguous()
    mask = mask.contiguous()
    numel = inp.numel()

    def grid(meta):
        # One program per tile, capped at the number of physical cores.
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(inp.device):
        # numel int32 counts is a safe upper bound on the tile count, which
        # is unknown until the autotuner picks BLOCK_SIZE.
        packed_vals = torch.empty(numel, dtype=inp.dtype, device=inp.device)
        tile_counts = torch.empty(numel, dtype=torch.int32, device=inp.device)
        masked_select_kernel[grid](inp, mask, packed_vals, tile_counts, numel)

        # The second pass must run with the tile size the autotuner chose.
        tuned_block = masked_select_kernel.best_config.kwargs["BLOCK_SIZE"]
        out_len = mask.sum().item()
        result = torch.empty(out_len, dtype=inp.dtype, device=inp.device)
        get_out_kernel[grid](packed_vals, tile_counts, result, numel, tuned_block)

    return result