Coverage for src/flag_gems/ops/masked_select.py: 49%

105 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import broadcastable, libentry 

9from flag_gems.utils.shape_utils import bracket_next_power_of_2 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.jit
def masked_select_single_pass_kernel(
    inp_ptr, mask_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr
):
    # Single-CTA compaction kernel. Launched with grid (1,) and
    # BLOCK_SIZE = next_power_of_2(N), so one program covers the whole input
    # and an in-block cumsum directly yields each selected element's output
    # position.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp = tl.load(inp_ptr + offsets, mask=offsets < N)
    # other=0: masked-off lanes of a tl.load are undefined without `other`,
    # and they feed the cumsum below. Out-of-bounds lanes trail the valid
    # ones here, so this is defensive, but it makes correctness explicit.
    mask = tl.load(mask_ptr + offsets, mask=offsets < N, other=0).to(tl.int1)
    mask_ints = mask.to(tl.int32)
    # Inclusive cumsum minus one == exclusive output index per selected lane.
    out_offsets = tl.cumsum(mask_ints, axis=0) - 1

    # Only lanes that are in bounds AND selected write; unselected lanes would
    # otherwise clobber out_ptr with stale out_offsets.
    tl.store(out_ptr + out_offsets, inp, mask=(offsets < N) & mask)

27 

28 

def masked_select_single_pass(inp, mask, out, N):
    """Compact `inp` into `out` with a single-CTA kernel launch.

    `out` must already be sized to the number of True entries in `mask`;
    it is filled in place and returned. Intended for small N (the whole
    input fits in one block).
    """
    block = triton.next_power_of_2(N)
    # Scale warp count with block width: 4 up to 512 lanes, 8 up to 2048,
    # 16 beyond that.
    if block > 2048:
        warps = 16
    elif block > 512:
        warps = 8
    else:
        warps = 4
    grid = (1,)
    masked_select_single_pass_kernel[grid](
        inp, mask, out, N, BLOCK_SIZE=block, num_warps=warps
    )
    return out

41 

42 

@libentry()
@triton.jit(do_not_specialize=["N", "nr", "row_stride"])
def mask_part_sum_kernel(
    inp_ptr,  # input values; unused here — signature mirrors write_back_kernel
    mask_ptr,  # flattened selection mask of length N
    part_sums_ptr,  # out: entries [0, np) become exclusive prefix sums, [np] the total
    counter_ptr,  # scalar arrival counter used as a software barrier across programs
    N,  # number of mask elements
    num_blocks,  # total number of BLOCK_SIZE-wide blocks covering N
    num_blocks_per_row,  # blocks owned by each program ("row")
    NP_BLOCK: tl.constexpr,  # power of two >= number of launched programs
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of the two-pass masked_select: each program reduces its run of
    # consecutive mask blocks to a single count of selected elements.
    row_id = tl.program_id(0)
    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    acc = tl.zeros((BLOCK_SIZE,), dtype=part_sums_ptr.dtype.element_ty)

    # Last block this program owns (clamped so the final row stops at num_blocks).
    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    # Full interior blocks need no bounds mask; the tail block is peeled below.
    for block_id in range(start_block, last_block_id):
        select = tl.load(mask_ptr + offset)
        select_ints = select.to(part_sums_ptr.dtype.element_ty)
        acc += select_ints
        offset += BLOCK_SIZE
    # Peeled last block: may extend past N, so load with a bounds mask.
    select = tl.load(mask_ptr + offset, mask=offset < N, other=0)
    select_ints = select.to(part_sums_ptr.dtype.element_ty)
    acc += select_ints

    part_sum = tl.sum(acc, axis=0)
    tl.store(part_sums_ptr + row_id, part_sum)
    # cumsum the part_sums: the atomic counter detects the last program to
    # arrive; that one program rewrites all per-row counts as exclusive prefix
    # sums in place and appends the grand total at part_sums_ptr[np], so the
    # host can size the output and the write-back pass knows each row's start.
    count = tl.atomic_add(counter_ptr, 1, sem="acq_rel")
    np = tl.num_programs(0)
    if count == np - 1:
        mask = tl.arange(0, NP_BLOCK) < np
        part_sums = tl.load(part_sums_ptr + tl.arange(0, NP_BLOCK), mask=mask)
        final_sum = tl.sum(part_sums, axis=0)
        pre_sums = tl.cumsum(part_sums, axis=0)
        # inclusive cumsum minus own count == exclusive prefix sum
        tl.store(
            part_sums_ptr + tl.arange(0, NP_BLOCK), pre_sums - part_sums, mask=mask
        )
        tl.store(part_sums_ptr + np, final_sum)

87 

88 

@libentry()
@triton.jit(do_not_specialize=["N", "nr", "row_stride"])
def write_back_kernel(
    inp_ptr,  # flattened input values of length N
    mask_ptr,  # flattened selection mask of length N
    part_sums_ptr,  # per-row exclusive prefix sums produced by mask_part_sum_kernel
    out_ptr,  # output buffer sized to the total selected count
    N,
    num_blocks,  # total number of BLOCK_SIZE-wide blocks covering N
    num_blocks_per_row,  # blocks owned by each program ("row")
    NP_BLOCK: tl.constexpr,  # NOTE(review): unused here; kept for signature symmetry — confirm
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2 of the two-pass masked_select: each program compacts the selected
    # values of its row of blocks into out_ptr, starting at the row's exclusive
    # prefix offset from pass 1.
    row_id = tl.program_id(0)

    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # `advance` is how far out_ptr moves before this block's writes: initially
    # the row's global start, thereafter the previous block's selected count.
    advance = tl.load(part_sums_ptr + row_id)

    # Last block this program owns (clamped so the final row stops at num_blocks).
    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    # Full interior blocks need no bounds mask; the tail block is peeled below.
    for block_id in range(start_block, last_block_id):
        inp = tl.load(inp_ptr + offset)
        select_mask = tl.load(mask_ptr + offset).to(tl.int1)
        select_ints = select_mask.to(tl.constexpr(part_sums_ptr.dtype.element_ty))
        out_ptr += advance
        advance = tl.sum(select_ints, axis=0)
        # in-block exclusive output index for each selected lane
        pre_sums = tl.cumsum(select_ints, axis=0) - 1
        tl.store(out_ptr + pre_sums, inp, mask=select_mask)
        offset += BLOCK_SIZE
    # Peeled last block: may extend past N, so load/store with a bounds mask.
    inp = tl.load(inp_ptr + offset, mask=offset < N)
    select_mask = tl.load(mask_ptr + offset, mask=offset < N, other=0).to(tl.int1)
    select_ints = select_mask.to(tl.constexpr(part_sums_ptr.dtype.element_ty))
    out_ptr += advance
    pre_sums = tl.cumsum(select_ints, axis=0) - 1
    tl.store(out_ptr + pre_sums, inp, mask=(offset < N) & select_mask)

126 

127 

def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of `inp` where `mask` is True.

    GEMS implementation of torch.masked_select. `inp` and `mask` are
    broadcast against each other; selection order is row-major over the
    broadcast (contiguous) layout.

    Args:
        inp: input tensor.
        mask: tensor broadcastable with `inp`, interpreted as a boolean mask.

    Returns:
        New 1-D tensor (on `inp`'s device, with `inp`'s dtype) holding the
        selected elements.

    Raises:
        AssertionError: if the shapes are not broadcastable.
    """
    logger.debug("GEMS MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    # The kernels index both tensors as flat contiguous buffers.
    inp = inp.contiguous()
    mask = mask.contiguous()

    N = inp.numel()
    if N <= 4096:
        # Small input: single-CTA path. mask.sum() sizes the output (implies a
        # host/device sync, acceptable at this size).
        out = torch.empty(mask.sum(), dtype=inp.dtype, device=inp.device)
        return masked_select_single_pass(inp, mask, out, N)

    BLOCK_SIZE = bracket_next_power_of_2(N, 128, 4096)
    num_warps = min(16, BLOCK_SIZE // 32)

    # Max degree of parallelism: at most one program ("row") per SM.
    num_programs = torch_device_fn.get_device_properties(
        mask.device
    ).multi_processor_count

    # Arrange the n_blocks BLOCK_SIZE-wide blocks as num_programs rows of
    # n_blocks_per_row blocks each.
    n_blocks = triton.cdiv(N, BLOCK_SIZE)
    num_programs = min(n_blocks, num_programs)
    n_blocks_per_row = triton.cdiv(n_blocks, num_programs)
    num_programs = triton.cdiv(n_blocks, n_blocks_per_row)
    NP_BLOCK = triton.next_power_of_2(num_programs)

    with torch_device_fn.device(inp.device):
        # Pass 1: per-row selected-element counts, converted in place to
        # exclusive prefix sums; part_sums[-1] receives the grand total.
        # int32 suffices while every count fits below 2**31.
        dtype = torch.int32 if N < 2**31 else torch.int64
        part_sums = torch.empty(num_programs + 1, dtype=dtype, device=mask.device)
        barrier = torch.zeros([], dtype=torch.int, device=mask.device)
        mask_part_sum_kernel[(num_programs,)](
            inp,
            mask,
            part_sums,
            barrier,
            N,
            n_blocks,
            n_blocks_per_row,
            NP_BLOCK=NP_BLOCK,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        # Pass 2: compact the selected values. Reading part_sums[-1] syncs with
        # the device to size the output.
        out = torch.empty(part_sums[-1], dtype=inp.dtype, device=mask.device)
        write_back_kernel[(num_programs,)](
            inp,
            mask,
            part_sums,
            out,
            N,
            n_blocks,
            n_blocks_per_row,
            NP_BLOCK=NP_BLOCK,  # reuse instead of recomputing next_power_of_2
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return out