# src/flag_gems/ops/masked_scatter.py

import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import broadcastable, libentry
from flag_gems.utils.shape_utils import bracket_next_power_of_2

logger = logging.getLogger(__name__)


@libentry()
@triton.jit
def masked_scatter_single_pass_kernel(
    inp_ptr, mask_ptr, src_ptr, N, BLOCK_SIZE: tl.constexpr
):
    # Single-block kernel for small inputs: one program covers all N elements,
    # so a block-local cumsum over the mask is also the global prefix sum.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    block_mask = offsets < N

    mask_val = tl.load(mask_ptr + offsets, mask=block_mask, other=0).to(tl.int1)

    # The k-th True lane of the mask consumes source element k: an inclusive
    # cumsum minus one yields exactly that index on every lane.
    mask_ints = mask_val.to(tl.int32)
    src_indices = tl.cumsum(mask_ints, axis=0) - 1

    # Lanes before the first True hold src_indices == -1; the load mask keeps
    # them from reading out of bounds.
    active = block_mask & mask_val
    src_val = tl.load(src_ptr + src_indices, mask=active)
    tl.store(inp_ptr + offsets, src_val, mask=active)


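# Illustrative reference (a hedged sketch, not part of the op): the kernel
# above implements the same index mapping as this eager PyTorch helper. The
# helper name `_masked_scatter_reference` is hypothetical and exists only to
# document the cumsum trick on flat tensors.
def _masked_scatter_reference(inp, mask, source):
    out = inp.clone().reshape(-1)
    flat_mask = mask.reshape(-1)
    # Inclusive cumsum minus one: rank of each True lane among all Trues.
    src_indices = torch.cumsum(flat_mask.to(torch.int64), dim=0) - 1
    out[flat_mask] = source.reshape(-1)[src_indices[flat_mask]]
    return out.reshape(inp.shape)
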

@libentry()
@triton.jit(do_not_specialize=["N", "num_blocks", "num_blocks_per_row"])
def mask_part_sum_kernel(
    mask_ptr,
    part_sums_ptr,
    counter_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    NP_BLOCK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # First pass: each program sums the mask over its row of blocks and
    # publishes the partial count to part_sums_ptr[row_id].
    row_id = tl.program_id(0)
    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    acc = tl.zeros((BLOCK_SIZE,), dtype=part_sums_ptr.dtype.element_ty)

    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    # Only the globally last block can be partial, so the loop over the row's
    # leading blocks loads without a bounds mask; the peeled final block
    # masks on N.
    for block_id in range(start_block, last_block_id):
        select = tl.load(mask_ptr + offset)
        select_ints = select.to(part_sums_ptr.dtype.element_ty)
        acc += select_ints
        offset += BLOCK_SIZE

    select = tl.load(mask_ptr + offset, mask=offset < N, other=0)
    select_ints = select.to(part_sums_ptr.dtype.element_ty)
    acc += select_ints

    part_sum = tl.sum(acc, axis=0)
    tl.store(part_sums_ptr + row_id, part_sum)

    # Grid-wide rendezvous via an atomic counter: the last program to arrive
    # rewrites the partial sums as exclusive prefix sums in place and stores
    # the grand total in the extra trailing slot part_sums_ptr[np].
    count = tl.atomic_add(counter_ptr, 1, sem="acq_rel")
    np = tl.num_programs(0)

    if count == np - 1:
        mask = tl.arange(0, NP_BLOCK) < np
        part_sums = tl.load(part_sums_ptr + tl.arange(0, NP_BLOCK), mask=mask)
        final_sum = tl.sum(part_sums, axis=0)
        pre_sums = tl.cumsum(part_sums, axis=0)
        tl.store(
            part_sums_ptr + tl.arange(0, NP_BLOCK), pre_sums - part_sums, mask=mask
        )
        tl.store(part_sums_ptr + np, final_sum)


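# Hedged sketch of what the single "last program" above computes, expressed
# in eager PyTorch (the helper name is hypothetical): each row's entry becomes
# the exclusive prefix sum of the per-row True counts, with the grand total
# appended in the trailing slot.
def _finalize_part_sums_reference(part_sums):
    totals = part_sums.to(torch.int64)
    exclusive = torch.cumsum(totals, dim=0) - totals  # inclusive -> exclusive
    return torch.cat([exclusive, totals.sum().reshape(1)])
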

@libentry()
@triton.jit(do_not_specialize=["N", "num_blocks", "num_blocks_per_row"])
def masked_scatter_kernel(
    inp_ptr,
    mask_ptr,
    src_ptr,
    part_sums_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    BLOCK_SIZE: tl.constexpr,
):
    # Second pass: each program starts from its row's exclusive prefix sum
    # (`advance`) and scatters source values block by block.
    row_id = tl.program_id(0)

    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    advance = tl.load(part_sums_ptr + row_id)

    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    for block_id in range(start_block, last_block_id):
        select_mask = tl.load(mask_ptr + offset).to(tl.int1)
        select_ints = select_mask.to(tl.int32)

        # Block-local exclusive rank plus the running row offset gives each
        # True lane's global position within the flattened source.
        block_cumsum = tl.cumsum(select_ints, axis=0) - 1
        global_src_idx = advance + block_cumsum

        advance += tl.sum(select_ints, axis=0)

        src_val = tl.load(src_ptr + global_src_idx, mask=select_mask)
        tl.store(inp_ptr + offset, src_val, mask=select_mask)

        offset += BLOCK_SIZE

    # Peeled last block of the row, bounds-masked on N.
    block_mask = offset < N
    select_mask = tl.load(mask_ptr + offset, mask=block_mask, other=0).to(tl.int1)

    select_ints = select_mask.to(tl.int32)
    block_cumsum = tl.cumsum(select_ints, axis=0) - 1
    global_src_idx = advance + block_cumsum

    active = block_mask & select_mask
    src_val = tl.load(src_ptr + global_src_idx, mask=active)
    tl.store(inp_ptr + offset, src_val, mask=active)


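# Hedged sketch (hypothetical helper) of the per-block index math above:
# given the count of Trues consumed so far (`advance`) and one block of mask
# values, compute which flattened source element each True lane reads and the
# offset carried into the next block.
def _block_src_indices_reference(advance, block_mask):
    ints = block_mask.to(torch.int64)
    idx = advance + torch.cumsum(ints, dim=0) - 1  # global source index per lane
    next_advance = advance + int(ints.sum())       # carried into the next block
    return idx, next_advance
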

def masked_scatter_impl(inp, mask, source, N):
    # Small inputs fit in one block, where a single cumsum suffices.
    if N <= 4096:
        BLOCK_SIZE = triton.next_power_of_2(N)
        num_warps = 4
        if BLOCK_SIZE >= 2048:
            num_warps = 8
        if BLOCK_SIZE >= 4096:
            num_warps = 16

        masked_scatter_single_pass_kernel[(1,)](
            inp, mask, source, N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
        )
        return inp

    BLOCK_SIZE = bracket_next_power_of_2(N, 128, 4096)
    num_warps = min(16, BLOCK_SIZE // 32)

    # Cap the grid at the SM count, then re-round so that
    # np * n_blocks_per_row still covers all n_blocks with no idle rows.
    np = torch_device_fn.get_device_properties(mask.device).multi_processor_count
    n_blocks = triton.cdiv(N, BLOCK_SIZE)
    np = min(n_blocks, np)
    n_blocks_per_row = triton.cdiv(n_blocks, np)
    np = triton.cdiv(n_blocks, n_blocks_per_row)
    NP_BLOCK = triton.next_power_of_2(np)

    with torch_device_fn.device(inp.device):
        # int32 offsets suffice below 2**31 elements; int64 otherwise.
        dtype = torch.int32 if N < 2**31 else torch.int64
        part_sums = torch.empty(np + 1, dtype=dtype, device=mask.device)
        barrier = torch.zeros([], dtype=torch.int, device=mask.device)

        mask_part_sum_kernel[(np,)](
            mask,
            part_sums,
            barrier,
            N,
            n_blocks,
            n_blocks_per_row,
            NP_BLOCK=NP_BLOCK,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        masked_scatter_kernel[(np,)](
            inp,
            mask,
            source,
            part_sums,
            N,
            n_blocks,
            n_blocks_per_row,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return inp


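# Worked example of the grid arithmetic above (a sketch; the SM count of 108
# is an assumed value for illustration). With N = 10_000_000 and
# BLOCK_SIZE = 4096:
#   n_blocks         = cdiv(10_000_000, 4096) = 2442
#   np               = min(2442, 108)         = 108 programs
#   n_blocks_per_row = cdiv(2442, 108)        = 23 blocks per program
#   np (re-rounded)  = cdiv(2442, 23)         = 107 programs actually launched
# The grid thus shrinks to the smallest program count whose rows still cover
# every block.
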

def masked_scatter(inp, mask, source):
    logger.debug("GEMS MASKED SCATTER")

    assert broadcastable(
        inp.shape, mask.shape
    ), "The shapes of `mask` and `input` must be broadcastable"

    _, mask = torch.broadcast_tensors(inp, mask)

    # Work on a contiguous copy so the flat-index kernels see row-major data.
    out = inp.clone()
    if not out.is_contiguous():
        out = out.contiguous()
    if not mask.is_contiguous():
        mask = mask.contiguous()
    if not source.is_contiguous():
        source = source.contiguous()

    N = out.numel()

    masked_scatter_impl(out, mask, source, N)

    return out


def masked_scatter_(inp, mask, source):
    logger.debug("GEMS MASKED SCATTER_")

    assert broadcastable(
        inp.shape, mask.shape
    ), "The shapes of `mask` and `input` must be broadcastable"
    _, mask = torch.broadcast_tensors(inp, mask)

    if not inp.is_contiguous():
        raise RuntimeError(
            "In-place masked_scatter_ currently requires a contiguous input tensor."
        )

    mask = mask if mask.is_contiguous() else mask.contiguous()
    source = source if source.is_contiguous() else source.contiguous()

    N = inp.numel()
    masked_scatter_impl(inp, mask, source, N)

    return inp


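# Hedged usage sketch: a quick self-check against PyTorch's eager
# masked_scatter. The "cuda" device string and the tensor shapes are
# illustrative assumptions, not part of the library's test suite.
if __name__ == "__main__":
    inp = torch.zeros(2, 4, device="cuda")
    mask = torch.tensor(
        [[True, False, True, False], [False, True, True, False]], device="cuda"
    )
    source = torch.arange(1.0, 5.0, device="cuda")  # needs >= mask.sum() elements
    out = masked_scatter(inp, mask, source)
    # Source values fill True positions of the mask in row-major order.
    torch.testing.assert_close(out, inp.masked_scatter(mask, source))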