Coverage for src/flag_gems/runtime/backend/_mthreads/ops/repeat.py: 0%
176 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as tle
9from flag_gems.utils.libentry import libentry
11logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8),
    ],
    key=["out_shape0", "out_shape1"],
)
@triton.jit
def repeat_kernel_2d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    out_stride0,
    out_stride1,
    inp_shape0,
    inp_shape1,
    out_shape0,
    out_shape1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """2D repeat kernel: out[i, j] = inp[i % inp_shape0, j % inp_shape1].

    Each program writes one BLOCK_M x BLOCK_N tile of the output; the grid is
    (cdiv(out_shape0, BLOCK_M), cdiv(out_shape1, BLOCK_N)). Strides are in
    elements, so arbitrary (strided) input/output layouts are supported.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Output coordinates covered by this tile.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Mask off the partial tile at the right/bottom edges of the output.
    mask_m = offs_m < out_shape0
    mask_n = offs_n < out_shape1
    mask = mask_m[:, None] & mask_n[None, :]

    # Map output indices to input indices using modulo
    inp_offs_m = offs_m % inp_shape0
    inp_offs_n = offs_n % inp_shape1

    # Load from input
    inp_ptrs = (
        inp_ptr + inp_offs_m[:, None] * inp_stride0 + inp_offs_n[None, :] * inp_stride1
    )
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    # Store to output
    out_ptrs = out_ptr + offs_m[:, None] * out_stride0 + offs_n[None, :] * out_stride1
    tl.store(out_ptrs, data, mask=mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
    ],
    key=["out_shape0"],
)
@triton.jit
def repeat_kernel_1d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    out_stride0,
    inp_shape0,
    out_shape0,
    BLOCK_SIZE: tl.constexpr,
):
    """1D repeat kernel: out[i] = inp[i % inp_shape0] for i in [0, out_shape0).

    Each program handles one contiguous BLOCK_SIZE span of output indices;
    the grid is cdiv(out_shape0, BLOCK_SIZE). Strides are in elements.
    """
    pid = tle.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask off the tail of the final block.
    mask = offs < out_shape0

    # Map output indices to input indices
    inp_offs = offs % inp_shape0

    # Load and store
    data = tl.load(inp_ptr + inp_offs * inp_stride0, mask=mask)
    tl.store(out_ptr + offs * out_stride0, data, mask=mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4),
        triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8),
        triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4),
    ],
    key=["out_shape1", "out_shape2"],
)
@triton.jit
def repeat_kernel_3d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    out_stride0,
    out_stride1,
    out_stride2,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    out_shape0,
    out_shape1,
    out_shape2,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Process 3D repeat: one program handles one (m, n_block, k_block).

    out[m, n, k] = inp[m % inp_shape0, n % inp_shape1, k % inp_shape2].
    Grid axis 0 enumerates m; axis 1 is the flattened (n_block, k_block)
    pair, decoded below with div/mod. Strides are in elements.
    """
    pid_m = tle.program_id(0)
    pid_nk = tle.program_id(1)

    # Decode the fused second grid axis into (n_block, k_block).
    num_k_blocks = tl.cdiv(out_shape2, BLOCK_K)
    pid_n = pid_nk // num_k_blocks
    pid_k = pid_nk % num_k_blocks

    m_idx = pid_m
    # Defensive guard; the launcher sizes axis 0 to out_shape0 exactly.
    if m_idx >= out_shape0:
        return

    # Modulo maps each output coordinate back to its source element.
    inp_m = m_idx % inp_shape0

    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    # Mask the partial tile at the n/k edges of the output.
    mask_n = offs_n < out_shape1
    mask_k = offs_k < out_shape2
    mask = mask_n[:, None] & mask_k[None, :]

    inp_n = offs_n % inp_shape1
    inp_k = offs_k % inp_shape2

    inp_ptrs = (
        inp_ptr
        + inp_m * inp_stride0
        + inp_n[:, None] * inp_stride1
        + inp_k[None, :] * inp_stride2
    )
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    out_ptrs = (
        out_ptr
        + m_idx * out_stride0
        + offs_n[:, None] * out_stride1
        + offs_k[None, :] * out_stride2
    )
    tl.store(out_ptrs, data, mask=mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 64}, num_warps=4),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 64}, num_warps=8),
        triton.Config({"BLOCK_K": 128, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 128}, num_warps=4),
        triton.Config({"BLOCK_K": 128, "BLOCK_L": 64}, num_warps=8),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 128}, num_warps=8),
    ],
    key=["out_shape2", "out_shape3"],
)
@triton.jit
def repeat_kernel_4d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    BLOCK_K: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """Process 4D repeat: one program handles one (m, n, k_block, l_block).

    out[m, n, k, l] = inp[m % s0, n % s1, k % s2, l % s3]. Grid axis 0 is the
    flattened (m, n) pair (out_shape0 * out_shape1 programs); axis 1 is the
    flattened (k_block, l_block) pair. Strides are in elements.
    """
    pid_mn = tle.program_id(0)
    pid_kl = tle.program_id(1)

    # Decode the fused (k_block, l_block) grid axis.
    num_l_blocks = tl.cdiv(out_shape3, BLOCK_L)
    pid_k = pid_kl // num_l_blocks
    pid_l = pid_kl % num_l_blocks

    # Flatten m, n
    m_idx = pid_mn // out_shape1
    n_idx = pid_mn % out_shape1

    # Defensive guard; the launcher sizes axis 0 to out_shape0 * out_shape1.
    if m_idx >= out_shape0:
        return

    # Modulo maps each output coordinate back to its source element.
    inp_m = m_idx % inp_shape0
    inp_n = n_idx % inp_shape1

    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)

    # Mask the partial tile at the k/l edges of the output.
    mask_k = offs_k < out_shape2
    mask_l = offs_l < out_shape3
    mask = mask_k[:, None] & mask_l[None, :]

    inp_k = offs_k % inp_shape2
    inp_l = offs_l % inp_shape3

    inp_ptrs = (
        inp_ptr
        + inp_m * inp_stride0
        + inp_n * inp_stride1
        + inp_k[:, None] * inp_stride2
        + inp_l[None, :] * inp_stride3
    )
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    out_ptrs = (
        out_ptr
        + m_idx * out_stride0
        + n_idx * out_stride1
        + offs_k[:, None] * out_stride2
        + offs_l[None, :] * out_stride3
    )
    tl.store(out_ptrs, data, mask=mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=16),
    ],
    key=["num_tasks"],
)
@triton.jit
def repeat_kernel_nd_flat(
    inp_ptr,
    out_ptr,
    num_tasks,
    # NOTE: unlike the rank-specialised kernels above, the argument order here
    # is shapes first, then strides. Dimension 0 is the outermost (slowest-
    # varying) one; unused leading dimensions are expected to be padded by the
    # caller (shape 1, stride 0 — see the launcher in `repeat`).
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    inp_shape4,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    out_shape4,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    inp_stride4,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    out_stride4,
    rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Generic N-D repeat kernel (up to 5D) using flat indexing with modulo.

    Each program walks flat output offsets with a stride of
    num_programs * BLOCK_SIZE, so a grid capped below
    cdiv(num_tasks, BLOCK_SIZE) programs still covers every element. For each
    flat offset the multi-dimensional output index is decoded innermost-first
    (dim 4 -> dim 0) via div/mod on the logical out_shape*, and each coordinate
    is folded back into the input with `% inp_shape*`. Since `rank` is a
    constexpr, the per-dimension branches are resolved at compile time.
    """
    pid = tle.program_id(0)
    num_ctas = tle.num_programs(0)

    for idx in range(pid * BLOCK_SIZE, num_tasks, num_ctas * BLOCK_SIZE):
        offs = idx + tl.arange(0, BLOCK_SIZE)
        mask = offs < num_tasks

        remaining = offs

        # Unroll for up to 5D
        if rank >= 5:
            out_idx4 = remaining % out_shape4
            inp_idx4 = out_idx4 % inp_shape4
            remaining = remaining // out_shape4
        else:
            # Inactive dimension: index 0 contributes nothing to the offsets.
            out_idx4 = tl.zeros_like(offs)
            inp_idx4 = tl.zeros_like(offs)

        if rank >= 4:
            out_idx3 = remaining % out_shape3
            inp_idx3 = out_idx3 % inp_shape3
            remaining = remaining // out_shape3
        else:
            out_idx3 = tl.zeros_like(offs)
            inp_idx3 = tl.zeros_like(offs)

        if rank >= 3:
            out_idx2 = remaining % out_shape2
            inp_idx2 = out_idx2 % inp_shape2
            remaining = remaining // out_shape2
        else:
            out_idx2 = tl.zeros_like(offs)
            inp_idx2 = tl.zeros_like(offs)

        if rank >= 2:
            out_idx1 = remaining % out_shape1
            inp_idx1 = out_idx1 % inp_shape1
            remaining = remaining // out_shape1
        else:
            out_idx1 = tl.zeros_like(offs)
            inp_idx1 = tl.zeros_like(offs)

        # Whatever is left is the outermost coordinate.
        out_idx0 = remaining
        inp_idx0 = out_idx0 % inp_shape0

        # Convert the decoded coordinates to element offsets via the strides.
        inp_offset = (
            inp_idx0 * inp_stride0
            + inp_idx1 * inp_stride1
            + inp_idx2 * inp_stride2
            + inp_idx3 * inp_stride3
            + inp_idx4 * inp_stride4
        )
        out_offset = (
            out_idx0 * out_stride0
            + out_idx1 * out_stride1
            + out_idx2 * out_stride2
            + out_idx3 * out_stride3
            + out_idx4 * out_stride4
        )

        data = tl.load(inp_ptr + inp_offset, mask=mask)
        tl.store(out_ptr + out_offset, data, mask=mask)
def repeat(inp: torch.Tensor, sizes) -> torch.Tensor:
    """Repeat ``inp`` along each dimension, like ``torch.Tensor.repeat``.

    ``sizes`` supplies one non-negative repetition count per output dimension
    and must be at least as long as ``inp.dim()``; extra leading entries act
    as if the input had leading size-1 dimensions. Ranks 1-4 dispatch to
    rank-specialised Triton kernels; anything larger uses the flat N-D kernel
    (padded to 5D). Returns a freshly allocated tensor with the same device
    and dtype as ``inp``; a zero repetition count yields an empty tensor.
    """
    logger.debug("GEMS_MTHREADS REPEAT")

    tensor_rank = inp.dim()
    n_reps = len(sizes)
    src_shape = list(inp.shape)
    reps = list(sizes)

    # repeat() requires at least as many repetition factors as dimensions.
    assert (
        n_reps >= tensor_rank
    ), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"

    # Left-pad the input shape with 1s so it lines up with the repeat counts.
    if n_reps > tensor_rank:
        src_shape = [1] * (n_reps - tensor_rank) + src_shape

    # Validate the counts and derive the output shape in one pass.
    has_zero = False
    out_shape = []
    for dim, rep in enumerate(reps):
        assert (
            rep >= 0
        ), f"the number of repetitions per dimension out of range (expected to >= 0) but got {rep}"
        if rep == 0:
            has_zero = True
        out_shape.append(src_shape[dim] * rep)

    out = torch.empty(out_shape, device=inp.device, dtype=inp.dtype)

    # Any zero count makes the output empty — nothing to launch.
    if has_zero:
        return out

    inp = inp.reshape(src_shape)
    rank = len(out_shape)
    total = out.numel()

    # Element strides for both operands (arbitrary layouts are supported).
    in_strides = list(inp.stride())
    out_strides = list(out.stride())

    with torch_device_fn.device(inp.device.index):
        if rank == 1:
            # 1D case with autotune; `or 1` guards against zero strides.
            grid = lambda META: (triton.cdiv(out_shape[0], META["BLOCK_SIZE"]),)
            repeat_kernel_1d[grid](
                inp,
                out,
                in_strides[0] or 1,
                out_strides[0] or 1,
                src_shape[0],
                out_shape[0],
            )
        elif rank == 2:
            # 2D case - tiled kernel over a 2D grid.
            grid = lambda META: (
                triton.cdiv(out_shape[0], META["BLOCK_M"]),
                triton.cdiv(out_shape[1], META["BLOCK_N"]),
            )
            repeat_kernel_2d[grid](
                inp,
                out,
                *in_strides,
                *out_strides,
                *src_shape,
                *out_shape,
            )
        elif rank == 3:
            # 3D case - one program per m, fused (n_block, k_block) axis.
            grid = lambda META: (
                out_shape[0],
                triton.cdiv(out_shape[1], META["BLOCK_N"])
                * triton.cdiv(out_shape[2], META["BLOCK_K"]),
            )
            repeat_kernel_3d[grid](
                inp,
                out,
                *in_strides,
                *out_strides,
                *src_shape,
                *out_shape,
            )
        elif rank == 4:
            # 4D case - fused (m, n) axis x fused (k_block, l_block) axis.
            grid = lambda META: (
                out_shape[0] * out_shape[1],
                triton.cdiv(out_shape[2], META["BLOCK_K"])
                * triton.cdiv(out_shape[3], META["BLOCK_L"]),
            )
            repeat_kernel_4d[grid](
                inp,
                out,
                *in_strides,
                *out_strides,
                *src_shape,
                *out_shape,
            )
        else:
            # 5D+ case - generic flat kernel. Left-pad shapes to 5 entries
            # (size 1) and strides to 5 entries (stride 0); the padded
            # dimensions always index 0, so the stride value is inert.
            src5 = list(src_shape)
            out5 = list(out_shape)
            in_s5 = list(in_strides)
            out_s5 = list(out_strides)
            while len(src5) < 5:
                src5 = [1] + src5
                out5 = [1] + out5
                in_s5 = [0] + in_s5
                out_s5 = [0] + out_s5

            # Cap the grid at 65535 CTAs; the kernel strides over the rest.
            grid = lambda META: (
                min(65535, triton.cdiv(total, META["BLOCK_SIZE"])),
            )
            repeat_kernel_nd_flat[grid](
                inp,
                out,
                total,
                *src5[:5],
                *out5[:5],
                *in_s5[:5],
                *out_s5[:5],
                rank=rank,
            )

    return out