Coverage for src/flag_gems/fused/bincount.py: 51%
120 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
10def _select_params(n):
11 if n <= 256:
12 return 256, 2
13 if n <= 1024:
14 return 256, 4
15 if n <= 4096:
16 return 512, 4
17 return 1024, 4
20def _estimate_output_size(n, minlength):
21 estimate = max(8192, n * 4, minlength)
22 estimate = min(estimate, 65536)
23 return max(estimate, minlength)
@triton.jit
def fused_max_bincount_kernel(
    input_ptr,  # 1-D integer values to histogram (assumed non-negative -- TODO confirm, matches torch.bincount's contract)
    max_ptr,  # single int64 cell; receives the global max of the input via atomic_max
    output_ptr,  # histogram buffer of length output_size, pre-zeroed by the caller
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First-pass kernel: accumulate counts into a guessed-size buffer while
    also computing the global maximum value, so the caller can detect whether
    the guess was large enough."""
    pid = tl.program_id(0)
    # Each program handles one contiguous BLOCK_SIZE chunk of the input.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Masked lanes load 0; that only keeps the max correct if inputs are >= 0.
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)

    local_max = tl.max(vals, axis=0)
    tl.atomic_max(max_ptr, local_max)

    # Drop values that fall outside the guessed buffer to keep the atomic_add
    # in-bounds; the caller re-runs a full pass if any were dropped.
    safe_mask = mask & (vals < output_size)
    tl.atomic_add(output_ptr + vals, 1, mask=safe_mask)
@triton.jit
def bincount_kernel(
    input_ptr,  # 1-D integer values to histogram (assumed non-negative -- TODO confirm)
    output_ptr,  # histogram buffer, pre-zeroed and sized >= max(input) + 1 by the caller
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Second-pass (fallback) kernel: plain unweighted bincount into a buffer
    the caller has already sized to fit every value -- no bounds mask here."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    tl.atomic_add(output_ptr + vals, 1, mask=mask)
@triton.jit
def fused_max_bincount_weights_fp32_kernel(
    input_ptr,  # 1-D integer bin indices (assumed non-negative -- TODO confirm)
    weights_ptr,  # per-element weights, same length as input
    max_ptr,  # single int64 cell; receives the global max of the input via atomic_max
    output_ptr,  # float32 histogram buffer of length output_size, pre-zeroed by the caller
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First-pass weighted bincount, accumulating in float32, fused with a
    global-max reduction so the caller can check its buffer-size guess."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    w = tl.load(weights_ptr + offsets, mask=mask, other=0.0)
    # Accumulate in fp32 regardless of the weights' storage dtype; the host
    # casts the result back for narrower dtypes (e.g. fp16).
    w_fp32 = w.to(tl.float32)

    local_max = tl.max(vals, axis=0)
    tl.atomic_max(max_ptr, local_max)

    # Skip values outside the guessed buffer; the caller re-runs a full pass
    # if the guess turned out too small.
    safe_mask = mask & (vals < output_size)
    tl.atomic_add(output_ptr + vals, w_fp32, mask=safe_mask)
@triton.jit
def bincount_weights_fp32_kernel(
    input_ptr,  # 1-D integer bin indices (assumed non-negative -- TODO confirm)
    weights_ptr,  # per-element weights, same length as input
    output_ptr,  # float32 histogram buffer, pre-zeroed and sized to fit by the caller
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Second-pass (fallback) weighted bincount in float32; the caller has
    already sized the buffer to max(input) + 1, so no bounds mask is needed."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    w = tl.load(weights_ptr + offsets, mask=mask, other=0.0)
    w_fp32 = w.to(tl.float32)
    tl.atomic_add(output_ptr + vals, w_fp32, mask=mask)
@triton.jit
def fused_max_bincount_weights_fp64_kernel(
    input_ptr,  # 1-D integer bin indices (assumed non-negative -- TODO confirm)
    weights_ptr,  # per-element weights, same length as input
    max_ptr,  # single int64 cell; receives the global max of the input via atomic_max
    output_ptr,  # float64 histogram buffer of length output_size, pre-zeroed by the caller
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First-pass weighted bincount accumulating in float64 (used when the
    weights are float64), fused with a global-max reduction."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    w = tl.load(weights_ptr + offsets, mask=mask, other=0.0)
    w_fp64 = w.to(tl.float64)

    local_max = tl.max(vals, axis=0)
    tl.atomic_max(max_ptr, local_max)

    # Skip values outside the guessed buffer; the caller re-runs a full pass
    # if the guess turned out too small.
    safe_mask = mask & (vals < output_size)
    tl.atomic_add(output_ptr + vals, w_fp64, mask=safe_mask)
@triton.jit
def bincount_weights_fp64_kernel(
    input_ptr,  # 1-D integer bin indices (assumed non-negative -- TODO confirm)
    weights_ptr,  # per-element weights, same length as input
    output_ptr,  # float64 histogram buffer, pre-zeroed and sized to fit by the caller
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Second-pass (fallback) weighted bincount in float64; the caller has
    already sized the buffer to max(input) + 1, so no bounds mask is needed."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask, other=0)
    w = tl.load(weights_ptr + offsets, mask=mask, other=0.0)
    w_fp64 = w.to(tl.float64)
    tl.atomic_add(output_ptr + vals, w_fp64, mask=mask)
def _fused_bincount_launch(
    input_contig,
    weights_contig,
    n,
    pre_size,
    minlength,
    out_dtype,
    grid,
    BLOCK_SIZE,
    num_warps,
):
    """Run the two-pass fused bincount strategy on the device.

    Pass 1 launches a fused kernel that histograms into a guessed buffer of
    ``pre_size`` bins while simultaneously reducing the global maximum input
    value. If ``max(input) + 1`` (or ``minlength``) fits inside the guess,
    the (trimmed) buffer is returned directly. Otherwise a correctly sized
    buffer is allocated and a plain second-pass kernel recounts everything.

    Args:
        input_contig: contiguous 1-D integer tensor of bin indices.
        weights_contig: contiguous weights tensor matching ``input_contig``,
            or ``None`` for an unweighted count.
        n: number of elements in ``input_contig``.
        pre_size: guessed size of the first-pass histogram buffer.
        minlength: minimum number of bins in the result.
        out_dtype: dtype requested by the caller (weights dtype or int64);
            only used here to decide whether to accumulate in float64.
        grid / BLOCK_SIZE / num_warps: Triton launch configuration.

    Returns:
        1-D tensor of per-bin counts (int64) or weighted sums (float32 or
        float64 accumulation dtype); the caller casts narrower float dtypes.
    """
    device = input_contig.device
    # Single int64 cell that the first-pass kernels atomic_max into.
    max_tensor = torch.zeros(1, dtype=torch.int64, device=device)

    is_fp64 = out_dtype == torch.float64
    # Accumulation dtype: int64 for plain counts, float64 only for float64
    # weights, float32 for every other weighted case (narrow dtypes are cast
    # back by the caller after accumulation).
    if weights_contig is None:
        compute_dtype = torch.int64
    elif is_fp64:
        compute_dtype = torch.float64
    else:
        compute_dtype = torch.float32

    output = torch.zeros(pre_size, dtype=compute_dtype, device=device)

    # Pass 1: fused count + global-max reduction into the guessed buffer.
    if weights_contig is None:
        fused_max_bincount_kernel[grid](
            input_contig,
            max_tensor,
            output,
            n,
            pre_size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    elif is_fp64:
        fused_max_bincount_weights_fp64_kernel[grid](
            input_contig,
            weights_contig,
            max_tensor,
            output,
            n,
            pre_size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    else:
        fused_max_bincount_weights_fp32_kernel[grid](
            input_contig,
            weights_contig,
            max_tensor,
            output,
            n,
            pre_size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    # Device-to-host sync: read back the true maximum to size the result.
    max_val = int(max_tensor.item())
    needed_size = max(max_val + 1, minlength)

    if needed_size <= pre_size:
        # The guess was large enough; trim to the exact bin count.
        return output[:needed_size]

    # Guess was too small: redo the count into a correctly sized buffer
    # (pass-1 counts for out-of-range bins were dropped, so start fresh).
    output = torch.zeros(needed_size, dtype=compute_dtype, device=device)
    if weights_contig is None:
        bincount_kernel[grid](
            input_contig,
            output,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    elif is_fp64:
        bincount_weights_fp64_kernel[grid](
            input_contig,
            weights_contig,
            output,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    else:
        bincount_weights_fp32_kernel[grid](
            input_contig,
            weights_contig,
            output,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    return output
def bincount(input, weights=None, minlength=0):
    """Triton-backed equivalent of ``torch.bincount``.

    Counts occurrences of each value in a 1-D integer tensor, optionally
    weighting each element. Values are assumed non-negative, matching
    torch.bincount's contract (NOTE(review): negative values are not
    validated here -- confirm upstream callers guarantee this).
    """
    logger.debug("GEMS BINCOUNT")

    assert input.dim() == 1, "input must be a 1-D tensor"
    assert minlength >= 0, "minlength must be non-negative"
    if weights is not None:
        assert weights.shape == input.shape, "weights must have the same shape as input"

    n = input.numel()

    # Empty input: result is just minlength zero bins in the output dtype.
    if n == 0:
        empty_dtype = torch.int64 if weights is None else weights.dtype
        return torch.zeros(minlength, dtype=empty_dtype, device=input.device)

    flat_input = input.contiguous()
    flat_weights = None if weights is None else weights.contiguous()

    BLOCK_SIZE, num_warps = _select_params(n)
    grid = (triton.cdiv(n, BLOCK_SIZE),)

    pre_size = _estimate_output_size(n, minlength)
    out_dtype = torch.int64 if weights is None else weights.dtype

    result = _fused_bincount_launch(
        flat_input,
        flat_weights,
        n,
        pre_size,
        minlength,
        out_dtype,
        grid,
        BLOCK_SIZE,
        num_warps,
    )

    # Non-fp32/fp64 weights are accumulated in fp32 on-device; cast the
    # result back to the caller's weight dtype.
    if weights is not None and weights.dtype not in (torch.float32, torch.float64):
        result = result.to(dtype=weights.dtype)

    return result