Coverage for src/flag_gems/fused/bincount.py: 51%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

10def _select_params(n): 

11 if n <= 256: 

12 return 256, 2 

13 if n <= 1024: 

14 return 256, 4 

15 if n <= 4096: 

16 return 512, 4 

17 return 1024, 4 

18 

19 

20def _estimate_output_size(n, minlength): 

21 estimate = max(8192, n * 4, minlength) 

22 estimate = min(estimate, 65536) 

23 return max(estimate, minlength) 

24 

25 

@triton.jit
def fused_max_bincount_kernel(
    input_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Single-pass unweighted bincount that also tracks the global max value.

    Counts go into a pre-sized histogram of ``output_size`` bins; values that
    do not fit are skipped here (the host re-runs a plain bincount if the
    observed max shows the guess was too small).
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Masked-out lanes read 0, which is harmless for the max reduction as
    # long as inputs are non-negative (bincount's input contract — TODO
    # confirm negatives are rejected upstream).
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)

    local_max = tl.max(vals, axis=0)
    tl.atomic_max(max_ptr, local_max)

    # Only count values that fit inside the pre-allocated output buffer.
    safe_mask = mask & (vals < output_size)
    tl.atomic_add(output_ptr + vals, 1, mask=safe_mask)

45 

46 

@triton.jit
def bincount_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Plain unweighted bincount into an output assumed large enough for
    every input value (used as the retry path once the true max is known)."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    tl.atomic_add(output_ptr + vals, 1, mask=mask)

59 

60 

@triton.jit
def fused_max_bincount_weights_fp32_kernel(
    input_ptr,
    weights_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted variant of the fused max+bincount kernel, accumulating in
    float32.

    Same skip-and-retry scheme: values >= ``output_size`` are not counted
    here; the host falls back to the plain weighted kernel if needed.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    w = tl.load(weights_ptr + offsets, mask=mask, other=0.0)
    # Accumulate in fp32 regardless of the weights' storage dtype; the host
    # casts the result back afterwards.
    w_fp32 = w.to(tl.float32)

    local_max = tl.max(vals, axis=0)
    tl.atomic_max(max_ptr, local_max)

    safe_mask = mask & (vals < output_size)
    tl.atomic_add(output_ptr + vals, w_fp32, mask=safe_mask)

83 

84 

@triton.jit
def bincount_weights_fp32_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted bincount accumulating in float32, output assumed large
    enough for every input value (retry path)."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    w = tl.load(weights_ptr + offsets, mask=mask, other=0.0)
    w_fp32 = w.to(tl.float32)
    tl.atomic_add(output_ptr + vals, w_fp32, mask=mask)

100 

101 

@triton.jit
def fused_max_bincount_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """float64 twin of ``fused_max_bincount_weights_fp32_kernel``, used when
    the weights are float64 so no precision is lost in accumulation."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    w = tl.load(weights_ptr + offsets, mask=mask, other=0.0)
    w_fp64 = w.to(tl.float64)

    local_max = tl.max(vals, axis=0)
    tl.atomic_max(max_ptr, local_max)

    # Skip out-of-range values; host retries with the plain kernel if needed.
    safe_mask = mask & (vals < output_size)
    tl.atomic_add(output_ptr + vals, w_fp64, mask=safe_mask)

124 

125 

@triton.jit
def bincount_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted bincount accumulating in float64, output assumed large
    enough for every input value (retry path)."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    w = tl.load(weights_ptr + offsets, mask=mask, other=0.0)
    w_fp64 = w.to(tl.float64)
    tl.atomic_add(output_ptr + vals, w_fp64, mask=mask)

141 

142 

def _fused_bincount_launch(
    input_contig,
    weights_contig,
    n,
    pre_size,
    minlength,
    out_dtype,
    grid,
    BLOCK_SIZE,
    num_warps,
):
    """Launch the fused max+bincount kernel, retrying when the size guess fails.

    Strategy: run once into a histogram of ``pre_size`` bins while tracking
    the global max input value. If ``max + 1`` (or ``minlength``) fits in the
    guess, return a view of the first ``needed_size`` bins; otherwise allocate
    the exact size and re-count with the plain (non-fused) kernel.

    Args:
        input_contig: contiguous 1-D integer tensor of values to count.
        weights_contig: contiguous weights tensor matching input, or None.
        n: number of elements in ``input_contig``.
        pre_size: heuristic output-size guess from ``_estimate_output_size``.
        minlength: minimum number of bins in the result.
        out_dtype: dtype requested by the caller (weights dtype, or int64).
        grid / BLOCK_SIZE / num_warps: launch configuration.

    Returns:
        Histogram tensor in the accumulation dtype (int64 counts without
        weights, float32/float64 sums with weights); the caller casts back
        to other weight dtypes.
    """
    max_tensor = torch.zeros(1, dtype=torch.int64, device=input_contig.device)
    is_fp64 = out_dtype == torch.float64
    # Accumulation dtype. (The original computed this with a conditional
    # expression and then redundantly re-assigned int64 for the no-weights
    # case; the two collapse into one explicit chain.)
    if weights_contig is None:
        compute_dtype = torch.int64
    elif is_fp64:
        compute_dtype = torch.float64
    else:
        compute_dtype = torch.float32

    output = torch.zeros(pre_size, dtype=compute_dtype, device=input_contig.device)

    # First pass: fused max + (bounds-checked) count into the guessed buffer.
    if weights_contig is None:
        fused_max_bincount_kernel[grid](
            input_contig,
            max_tensor,
            output,
            n,
            pre_size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    elif is_fp64:
        fused_max_bincount_weights_fp64_kernel[grid](
            input_contig,
            weights_contig,
            max_tensor,
            output,
            n,
            pre_size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    else:
        fused_max_bincount_weights_fp32_kernel[grid](
            input_contig,
            weights_contig,
            max_tensor,
            output,
            n,
            pre_size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    # Device sync point: read back the observed maximum value.
    max_val = int(max_tensor.item())
    needed_size = max(max_val + 1, minlength)

    if needed_size <= pre_size:
        # Guess was big enough — every value was counted in the first pass.
        return output[:needed_size]

    # Guess too small: allocate the exact size and re-count from scratch
    # with the plain kernels (no bounds check needed now).
    output = torch.zeros(needed_size, dtype=compute_dtype, device=input_contig.device)
    if weights_contig is None:
        bincount_kernel[grid](
            input_contig,
            output,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    elif is_fp64:
        bincount_weights_fp64_kernel[grid](
            input_contig,
            weights_contig,
            output,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    else:
        bincount_weights_fp32_kernel[grid](
            input_contig,
            weights_contig,
            output,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    return output

233 

234 

def bincount(input, weights=None, minlength=0):
    """Count occurrences of each value in a 1-D tensor, optionally weighted.

    Mirrors ``torch.bincount``: without weights the result is int64 counts;
    with weights it is the per-bin weight sums in the weights' dtype. The
    result has at least ``minlength`` bins.
    """
    logger.debug("GEMS BINCOUNT")

    assert input.dim() == 1, "input must be a 1-D tensor"
    assert minlength >= 0, "minlength must be non-negative"
    if weights is not None:
        assert weights.shape == input.shape, "weights must have the same shape as input"

    num_items = input.numel()

    # Empty input: nothing to count, just honor minlength.
    if num_items == 0:
        empty_dtype = torch.int64 if weights is None else weights.dtype
        return torch.zeros(minlength, dtype=empty_dtype, device=input.device)

    flat_input = input.contiguous()
    flat_weights = None if weights is None else weights.contiguous()

    block, warps = _select_params(num_items)
    launch_grid = (triton.cdiv(num_items, block),)
    guessed_size = _estimate_output_size(num_items, minlength)
    result_dtype = torch.int64 if weights is None else weights.dtype

    result = _fused_bincount_launch(
        flat_input,
        flat_weights,
        num_items,
        guessed_size,
        minlength,
        result_dtype,
        launch_grid,
        block,
        warps,
    )

    # Kernels accumulate weighted sums in fp32/fp64; cast back for any
    # other weight dtype so the output matches the weights.
    if weights is not None and weights.dtype not in (torch.float32, torch.float64):
        result = result.to(dtype=weights.dtype)

    return result