Coverage for src/flag_gems/runtime/backend/_mthreads/ops/repeat_interleave.py: 0%

225 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

11from flag_gems.utils.shape_utils import c_contiguous_stride 

12from flag_gems.utils.tensor_wrapper import StridedBuffer 

13 

# Module-level logger named after this op module, e.g.
# "flag_gems.runtime.backend._mthreads.ops.repeat_interleave".
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

17 

18 

@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    """Identity elementwise kernel; used to copy data between strided views."""
    return x

23 

24 

def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Repeat each element of ``inp`` ``repeats`` times along ``dim``.

    Args:
        inp: input tensor.
        repeats: non-negative int, number of repetitions per element.
        dim: dimension along which to repeat; when ``None`` the input is
            flattened first and ``dim`` becomes 0.
        output_size: optional expected size of the repeated dimension,
            validated against the computed size.

    Returns:
        A new contiguous tensor shaped like ``inp`` except that
        ``shape[dim]`` is multiplied by ``repeats``.

    Raises:
        IndexError: if ``dim`` is out of range.
        RuntimeError: if ``repeats`` is negative, or ``output_size`` does
            not match the computed output size.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_INT")
    if repeats < 0:
        # Previously a negative count fell through to torch.empty with a
        # confusing "negative dimension" error; reject it explicitly with
        # the same message used by the tensor-repeats path.
        raise RuntimeError("repeats can not be negative")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )
    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    if repeats == 0:
        return output

    # View the input with a broadcast (stride-0) axis of length `repeats`
    # inserted just after `dim`, and the output as the matching contiguous
    # buffer; a plain elementwise copy then realizes the interleaving.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output

67 

68 

@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    """Write the value `pid` into out[cumsum[pid]-repeats[pid] : cumsum[pid]).

    One program per element of `repeats`; together the programs build the
    index vector [0]*repeats[0] + [1]*repeats[1] + ... .
    """
    pid = tle.program_id(0)
    mask = pid < size  # guard in case the launch grid exceeds `size`
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    # cumsum is inclusive; subtracting this element's count gives the
    # (exclusive) start offset of its output run.
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats can not be negative")

    out_ptr += out_offset
    # Fill the run of length `repeats` in BLOCK_SIZE chunks.
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)

86 

87 

def repeat_interleave_tensor(repeats, *, output_size=None):
    """Return the 1D index tensor [0]*repeats[0] + [1]*repeats[1] + ... .

    Args:
        repeats: 1D tensor of non-negative per-index repeat counts.
        output_size: accepted for API compatibility; not used here.

    Returns:
        1D tensor of dtype ``repeats.dtype`` with ``repeats.sum()`` elements.

    Raises:
        RuntimeError: if ``repeats`` is not 1D or contains negative counts.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_TENSOR")

    if repeats.ndim != 1:
        # Raise (not `assert`) so the validation survives `python -O`.
        raise RuntimeError("repeat_interleave only accept 1D vector as repeat")

    size = repeats.size(0)
    if size == 0:
        # Empty repeats => empty output; also avoids indexing cumsum[-1]
        # on an empty tensor below.
        return torch.empty((0,), dtype=repeats.dtype, device=repeats.device)

    cumsum = repeats.cumsum(axis=0)
    result_size = cumsum[-1].item()

    if result_size < 0:
        raise RuntimeError("repeats can not be negative")

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)

    grid = (size,)
    BLOCK_SIZE = 32
    with torch_device_fn.device(repeats.device):
        repeat_interleave_tensor_kernel[grid](
            repeats,
            cumsum,
            out,
            size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=1,
        )
    return out

113 

114 

@libentry()
@triton.jit
def fused_repeat_interleave_dim0_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for repeat_interleave with dim=0.
    Each program handles one input row and copies to all its repeated output positions.
    """
    pid = tle.program_id(0)

    if pid >= num_input_rows:
        return

    # Output rows for input row `pid` span [cumsum[pid-1], cumsum[pid]);
    # the masked load yields 0 for pid == 0 (exclusive prefix).
    row_idx_mask = pid > 0
    start_row_idx = tl.load(cumsum_ptr + pid - 1, mask=row_idx_mask, other=0)
    end_row_idx = tl.load(cumsum_ptr + pid)

    num_of_rows = end_row_idx - start_row_idx
    if num_of_rows == 0:
        return

    # Input row offset (input buffer is contiguous, row-major).
    inp_row_offset = pid * row_size

    # Process columns in blocks; each loaded chunk is reused for every
    # repeated output row, so the input is read only once per chunk.
    for col_block in range(0, tl.cdiv(row_size, BLOCK_SIZE)):
        col_offsets = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        # Load from input
        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )

        # Store the same chunk to each of this row's output positions.
        for cur_row in range(0, num_of_rows):
            output_row_index = start_row_idx + cur_row
            output_row_offsets = output_row_index * row_size + col_offsets
            tl.store(out_ptr + output_row_offsets, cur_inp, mask=col_mask)

160 

161 

@libentry()
@triton.jit
def fused_repeat_interleave_output_centric_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Output-centric kernel for repeat_interleave with dim=0.
    Uses 2D grid: (num_output_rows, num_col_chunks).
    Uses binary search over the cumsum to find the source input row.
    """
    out_row_idx = tle.program_id(0)
    col_chunk_idx = tle.program_id(1)

    if out_row_idx >= num_output_rows:
        return

    # Binary search: find the smallest i such that cumsum[i] > out_row_idx.
    low = 0
    high = num_input_rows
    while low < high:
        mid = (low + high) // 2
        cumsum_mid = tl.load(cumsum_ptr + mid)
        if cumsum_mid <= out_row_idx:
            low = mid + 1
        else:
            high = mid

    # cumsum[low] > out_row_idx >= cumsum[low-1], so input row `low`
    # owns this output row.
    inp_row_idx = low

    # Column offsets for this program's chunk.
    col_offset = col_chunk_idx * BLOCK_SIZE
    col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < row_size

    # Load from input (contiguous, row-major).
    inp_row_offset = inp_row_idx * row_size
    cur_inp = tl.load(inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0)

    # Store to output
    out_row_offset = out_row_idx * row_size
    tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)

209 

210 

@libentry()
@triton.jit
def fused_repeat_interleave_1d_bsearch_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """1D output-centric kernel with binary search.
    Each program handles one complete output row (all columns), so the
    binary-search cost is paid once per row. Better for large row sizes.
    """
    out_row_idx = tle.program_id(0)

    if out_row_idx >= num_output_rows:
        return

    # Binary search: smallest i such that cumsum[i] > out_row_idx; that
    # input row is the source for this output row.
    low = 0
    high = num_input_rows
    while low < high:
        mid = (low + high) // 2
        cumsum_mid = tl.load(cumsum_ptr + mid)
        if cumsum_mid <= out_row_idx:
            low = mid + 1
        else:
            high = mid

    inp_row_idx = low

    # Row base offsets (both buffers contiguous, row-major).
    inp_row_offset = inp_row_idx * row_size
    out_row_offset = out_row_idx * row_size

    # Copy the whole row in BLOCK_SIZE chunks.
    for col_offset in range(0, row_size, BLOCK_SIZE):
        col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )
        tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)

257 

258 

@libentry()
@triton.jit
def fused_repeat_interleave_with_indices_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Output-centric kernel using a precomputed output-row -> input-row map.
    Uses 2D grid: (num_output_rows, num_col_chunks).

    NOTE(review): not called by any dispatcher visible in this file; kept
    presumably as an alternative strategy — confirm against other callers.
    """
    out_row_idx = tle.program_id(0)
    col_chunk_idx = tle.program_id(1)

    if out_row_idx >= num_output_rows:
        return

    # Load precomputed input row index (replaces the binary search used by
    # the cumsum-based kernels).
    inp_row_idx = tl.load(index_ptr + out_row_idx)

    # Column offsets for this program's chunk.
    col_offset = col_chunk_idx * BLOCK_SIZE
    col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < row_size

    # Load from input (contiguous, row-major).
    inp_row_offset = inp_row_idx * row_size
    cur_inp = tl.load(inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0)

    # Store to output
    out_row_offset = out_row_idx * row_size
    tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)

293 

294 

@libentry()
@triton.jit
def fused_repeat_interleave_large_row_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Kernel for large row sizes using a precomputed row-index map.
    Each program handles one output row and loops over all its columns.

    NOTE(review): not called by any dispatcher visible in this file; kept
    presumably as an alternative strategy — confirm against other callers.
    """
    out_row_idx = tle.program_id(0)

    if out_row_idx >= num_output_rows:
        return

    # Load precomputed input row index for this output row.
    inp_row_idx = tl.load(index_ptr + out_row_idx)

    # Row base offsets (both buffers contiguous, row-major).
    inp_row_offset = inp_row_idx * row_size
    out_row_offset = out_row_idx * row_size

    # Copy the whole row in BLOCK_SIZE chunks.
    for col_offset in range(0, row_size, BLOCK_SIZE):
        col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        # Load from input and store to output
        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )
        tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)

330 

331 

def fused_repeat_interleave_dim0(inp, repeats, dim):
    """Fused repeat_interleave along the leading dimension (dim == 0).

    Args:
        inp: input tensor (any rank); made contiguous internally so rows
            can be addressed as ``row * row_size + col``.
        repeats: 1D tensor of per-row repeat counts with
            ``len(repeats) == inp.shape[dim]``.
        dim: dimension being repeated; callers pass 0.

    Returns:
        New tensor shaped like ``inp`` except ``shape[dim] == repeats.sum()``.

    Raises:
        RuntimeError: if the repeat counts sum to a negative value.
    """
    logger.debug("GEMS_MTHREADS FUSED_REPEAT_INTERLEAVE_DIM0")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    out_shape = list(inp.shape)

    if repeats.numel() == 0:
        # Empty repeats => empty output. This early return also protects
        # cumsum[-1] (IndexError on an empty tensor) and the row_size
        # division by zero below.
        out_shape[dim] = 0
        return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # cumsum[i] is the first output row AFTER input row i's run; the
    # kernels recover each run as [cumsum[i-1], cumsum[i]).
    cumsum = repeats.cumsum(axis=0)
    total_output_rows = cumsum[-1].item()

    if total_output_rows < 0:
        # Mirror repeat_interleave_tensor's validation instead of letting
        # torch.empty fail on a negative dimension.
        raise RuntimeError("repeats can not be negative")

    out_shape[dim] = total_output_rows
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    if total_output_rows == 0:
        # All repeat counts are zero: nothing to copy, skip kernel launch.
        return out

    # Flatten non-dim dimensions for easier indexing.
    num_input_rows = inp.shape[dim]
    row_size = inp.numel() // num_input_rows

    # Make input contiguous for efficient row-major access.
    inp_contig = inp.contiguous()

    # Strategy selection:
    # 1. Small tensors: input-centric kernel
    # 2. Medium row sizes: output-centric 2D grid with binary search
    # 3. Large row sizes: output-centric 1D grid with binary search

    if row_size < 512 and total_output_rows < 512:
        # Small tensor: use input-centric kernel
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 4096)

        if BLOCK_SIZE <= 256:
            num_warps = 2
        elif BLOCK_SIZE <= 512:
            num_warps = 4
        else:
            num_warps = 8

        grid = (num_input_rows,)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_dim0_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
    elif row_size >= 16384:
        # Large row size: use 1D grid with binary search.
        # This reduces total number of programs and amortizes binary search cost.
        BLOCK_SIZE = 2048
        num_warps = 16

        grid = (total_output_rows,)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_1d_bsearch_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                total_output_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
    else:
        # Medium row size: use 2D grid with binary search
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 1024)
        num_col_chunks = triton.cdiv(row_size, BLOCK_SIZE)

        if BLOCK_SIZE <= 256:
            num_warps = 2
        elif BLOCK_SIZE <= 512:
            num_warps = 4
        else:
            num_warps = 8

        grid = (total_output_rows, num_col_chunks)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_output_centric_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                total_output_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )

    return out

435 

436 

def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """Repeat elements of ``inp`` along ``dim`` with per-element counts.

    Args:
        inp: input tensor.
        repeats: 0-dim or 1-dim tensor of repeat counts. A 0-dim (or
            length-1) tensor applies one count to every element.
        dim: dimension to repeat along; flattens ``inp`` when ``None``.
        output_size: optional expected size of the repeated dimension.

    Returns:
        Tensor shaped like ``inp`` with ``shape[dim] == repeats.sum()``.

    Raises:
        IndexError: if ``dim`` is out of range.
        RuntimeError: if ``repeats`` has more than one dimension or its
            length does not match ``inp.size(dim)``.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        # A scalar repeat count reduces to the integer-repeats fast path.
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        # Fixed: the old message used a backslash line-continuation inside
        # the string literal, which embedded a long run of spaces in the
        # user-facing error text; use implicit concatenation instead.
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # dim=0 has a dedicated fused-kernel path.
    if dim == 0:
        return fused_repeat_interleave_dim0(inp, repeats, dim)

    # For other dimensions, expand the counts into an index vector and
    # gather with index_select.
    indices = repeat_interleave_tensor(repeats)
    res = torch.index_select(inp, dim, indices)

    return res