Coverage for src/flag_gems/runtime/backend/_mthreads/ops/tile.py: 0%
174 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import triton_lang_extension as tle
from flag_gems.utils.libentry import libentry

# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8),
    ],
    key=["out_shape0", "out_shape1"],
)
@triton.jit
def tile_kernel_2d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    out_stride0,
    out_stride1,
    inp_shape0,
    inp_shape1,
    out_shape0,
    out_shape1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Tile a 2D tensor: each program fills one BLOCK_M x BLOCK_N output tile.

    Output element (i, j) is copied from input element
    (i % inp_shape0, j % inp_shape1).
    """
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tle.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    row_ok = rows < out_shape0
    col_ok = cols < out_shape1
    in_bounds = row_ok[:, None] & col_ok[None, :]

    # Wrap output coordinates back into the input's index space.
    src_rows = rows % inp_shape0
    src_cols = cols % inp_shape1

    # Gather the source tile.
    src_ptrs = (
        inp_ptr + src_rows[:, None] * inp_stride0 + src_cols[None, :] * inp_stride1
    )
    vals = tl.load(src_ptrs, mask=in_bounds, other=0.0)

    # Scatter it to the destination tile.
    dst_ptrs = out_ptr + rows[:, None] * out_stride0 + cols[None, :] * out_stride1
    tl.store(dst_ptrs, vals, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
    ],
    key=["out_shape0"],
)
@triton.jit
def tile_kernel_1d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    out_stride0,
    inp_shape0,
    out_shape0,
    BLOCK_SIZE: tl.constexpr,
):
    """Tile a 1D tensor: output element i is copied from input element
    i % inp_shape0."""
    block_start = tle.program_id(0) * BLOCK_SIZE
    out_idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_range = out_idx < out_shape0

    # Fold output positions onto the input via modulo, then copy.
    src_idx = out_idx % inp_shape0
    vals = tl.load(inp_ptr + src_idx * inp_stride0, mask=in_range)
    tl.store(out_ptr + out_idx * out_stride0, vals, mask=in_range)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4),
        triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8),
        triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4),
    ],
    key=["out_shape1", "out_shape2"],
)
@triton.jit
def tile_kernel_3d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    out_stride0,
    out_stride1,
    out_stride2,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    out_shape0,
    out_shape1,
    out_shape2,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Tile a 3D tensor: one program handles one (m, n_block, k_block).

    Grid layout: axis 0 enumerates output rows ``m`` (one per program);
    axis 1 enumerates flattened (n_block, k_block) tile pairs. Output
    element (m, n, k) is copied from input element
    (m % inp_shape0, n % inp_shape1, k % inp_shape2).
    """
    pid_m = tle.program_id(0)
    pid_nk = tle.program_id(1)
    # Decode the flattened second grid axis into (n_block, k_block).
    num_k_blocks = tl.cdiv(out_shape2, BLOCK_K)
    pid_n = pid_nk // num_k_blocks
    pid_k = pid_nk % num_k_blocks
    m_idx = pid_m
    # Guard: grid axis 0 may be launched larger than out_shape0.
    if m_idx >= out_shape0:
        return
    # Wrap the outer output row into the input's index space.
    inp_m = m_idx % inp_shape0
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    mask_n = offs_n < out_shape1
    mask_k = offs_k < out_shape2
    mask = mask_n[:, None] & mask_k[None, :]
    # Map output indices to input indices using modulo.
    inp_n = offs_n % inp_shape1
    inp_k = offs_k % inp_shape2
    inp_ptrs = (
        inp_ptr
        + inp_m * inp_stride0
        + inp_n[:, None] * inp_stride1
        + inp_k[None, :] * inp_stride2
    )
    data = tl.load(inp_ptrs, mask=mask, other=0.0)
    out_ptrs = (
        out_ptr
        + m_idx * out_stride0
        + offs_n[:, None] * out_stride1
        + offs_k[None, :] * out_stride2
    )
    tl.store(out_ptrs, data, mask=mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 64}, num_warps=4),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 64}, num_warps=8),
        triton.Config({"BLOCK_K": 128, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 128}, num_warps=4),
    ],
    key=["out_shape2", "out_shape3"],
)
@triton.jit
def tile_kernel_4d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    BLOCK_K: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """Tile a 4D tensor: one program handles one (m, n, k_block, l_block).

    Grid layout: axis 0 flattens the two outer output dims (m, n) — one
    (m, n) pair per program; axis 1 flattens (k_block, l_block) tiles of
    the two inner dims. Output element (m, n, k, l) is copied from input
    element (m % inp_shape0, n % inp_shape1, k % inp_shape2,
    l % inp_shape3).
    """
    pid_mn = tle.program_id(0)
    pid_kl = tle.program_id(1)
    # Decode the flattened (k_block, l_block) grid coordinate.
    num_l_blocks = tl.cdiv(out_shape3, BLOCK_L)
    pid_k = pid_kl // num_l_blocks
    pid_l = pid_kl % num_l_blocks
    # Flatten m, n: decode the flattened (m, n) grid coordinate.
    m_idx = pid_mn // out_shape1
    n_idx = pid_mn % out_shape1
    # Guard: grid axis 0 may be launched larger than out_shape0 * out_shape1.
    if m_idx >= out_shape0:
        return
    # Wrap the outer output coordinates into the input's index space.
    inp_m = m_idx % inp_shape0
    inp_n = n_idx % inp_shape1
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
    mask_k = offs_k < out_shape2
    mask_l = offs_l < out_shape3
    mask = mask_k[:, None] & mask_l[None, :]
    # Map inner output indices to input indices using modulo.
    inp_k = offs_k % inp_shape2
    inp_l = offs_l % inp_shape3
    inp_ptrs = (
        inp_ptr
        + inp_m * inp_stride0
        + inp_n * inp_stride1
        + inp_k[:, None] * inp_stride2
        + inp_l[None, :] * inp_stride3
    )
    data = tl.load(inp_ptrs, mask=mask, other=0.0)
    out_ptrs = (
        out_ptr
        + m_idx * out_stride0
        + n_idx * out_stride1
        + offs_k[:, None] * out_stride2
        + offs_l[None, :] * out_stride3
    )
    tl.store(out_ptrs, data, mask=mask)
@libentry()
@triton.jit
def tile_kernel_nd_flat(
    inp_ptr,
    out_ptr,
    num_tasks,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    inp_shape4,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    out_shape4,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    inp_stride4,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    out_stride4,
    rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Generic N-D tile kernel (up to 5D) using flat indexing with modulo.

    Each program walks the flat output index space in a grid-stride loop
    of BLOCK_SIZE-element chunks, decodes every flat index into up to 5
    output coordinates (innermost dim, shape4, first), wraps each
    coordinate into the input via modulo, and copies the element. Dims
    beyond ``rank`` get index 0, so their stride arguments cannot affect
    the computed offsets.
    """
    pid = tle.program_id(0)
    num_ctas = tle.num_programs(0)
    # Grid-stride loop: the launch may use fewer programs than blocks
    # (the host caps the grid), so each program may process several chunks.
    for idx in range(pid * BLOCK_SIZE, num_tasks, num_ctas * BLOCK_SIZE):
        offs = idx + tl.arange(0, BLOCK_SIZE)
        mask = offs < num_tasks
        remaining = offs
        # Unroll for up to 5D: decode innermost-out; ``rank`` is a
        # constexpr, so these branches are resolved at compile time.
        if rank >= 5:
            out_idx4 = remaining % out_shape4
            inp_idx4 = out_idx4 % inp_shape4
            remaining = remaining // out_shape4
        else:
            out_idx4 = tl.zeros_like(offs)
            inp_idx4 = tl.zeros_like(offs)
        if rank >= 4:
            out_idx3 = remaining % out_shape3
            inp_idx3 = out_idx3 % inp_shape3
            remaining = remaining // out_shape3
        else:
            out_idx3 = tl.zeros_like(offs)
            inp_idx3 = tl.zeros_like(offs)
        if rank >= 3:
            out_idx2 = remaining % out_shape2
            inp_idx2 = out_idx2 % inp_shape2
            remaining = remaining // out_shape2
        else:
            out_idx2 = tl.zeros_like(offs)
            inp_idx2 = tl.zeros_like(offs)
        if rank >= 2:
            out_idx1 = remaining % out_shape1
            inp_idx1 = out_idx1 % inp_shape1
            remaining = remaining // out_shape1
        else:
            out_idx1 = tl.zeros_like(offs)
            inp_idx1 = tl.zeros_like(offs)
        # Whatever is left is the outermost coordinate.
        out_idx0 = remaining
        inp_idx0 = out_idx0 % inp_shape0
        inp_offset = (
            inp_idx0 * inp_stride0
            + inp_idx1 * inp_stride1
            + inp_idx2 * inp_stride2
            + inp_idx3 * inp_stride3
            + inp_idx4 * inp_stride4
        )
        out_offset = (
            out_idx0 * out_stride0
            + out_idx1 * out_stride1
            + out_idx2 * out_stride2
            + out_idx3 * out_stride3
            + out_idx4 * out_stride4
        )
        data = tl.load(inp_ptr + inp_offset, mask=mask)
        tl.store(out_ptr + out_offset, data, mask=mask)
def tile(inp: torch.Tensor, dims) -> torch.Tensor:
    """Repeat ``inp`` ``dims[i]`` times along each dimension i (torch.tile).

    When ``len(dims)`` and ``inp.dim()`` differ, the shorter of the two is
    left-padded with ones, matching ``torch.tile`` semantics.

    Args:
        inp: input tensor.
        dims: sequence of per-dimension repetition counts (each >= 0).

    Returns:
        A new tensor whose shape is ``padded_in_shape[i] * padded_dims[i]``
        along every (rank-normalized) dimension.

    Raises:
        AssertionError: if any entry of ``dims`` is negative.
    """
    logger.debug("GEMS TILE")

    in0_rank = inp.dim()
    dims_rank = len(dims)
    in0_shape = list(inp.shape)
    dims_shape = list(dims)

    # Normalize ranks: left-pad the shorter sequence with ones so both
    # describe the same number of dimensions.
    if dims_rank < in0_rank:
        dims_shape = [1] * (in0_rank - dims_rank) + dims_shape
    elif dims_rank > in0_rank:
        in0_shape = [1] * (dims_rank - in0_rank) + in0_shape

    # Validate repeats and build the output shape; a zero repeat anywhere
    # makes the output empty, in which case no kernel launch is needed.
    is_empty = False
    out_shape = []
    for i in range(len(in0_shape)):
        assert (
            dims_shape[i] >= 0
        ), f"the number of repetitions per dimension out of range (expected to >= 0) but got {dims_shape[i]}"
        if dims_shape[i] == 0:
            is_empty = True
        out_shape.append(in0_shape[i] * dims_shape[i])

    out = torch.empty(out_shape, device=inp.device, dtype=inp.dtype)
    if is_empty:
        return out

    # Materialize the rank-normalized view so strides line up with in0_shape.
    inp = inp.reshape(in0_shape)
    rank = len(out_shape)
    num_tasks = out.numel()

    inp_strides = list(inp.stride())
    out_strides = list(out.stride())

    with torch_device_fn.device(inp.device.index):
        if rank == 1:
            # 1D case: one program per BLOCK_SIZE output elements.
            grid = lambda META: (triton.cdiv(out_shape[0], META["BLOCK_SIZE"]),)
            tile_kernel_1d[grid](
                inp,
                out,
                # NOTE(review): the stride-0 guard looks defensive — with a
                # size-1 dim the modulo index is always 0, so the substitute
                # value 1 should never change which element is read.
                inp_strides[0] if inp_strides[0] != 0 else 1,
                out_strides[0] if out_strides[0] != 0 else 1,
                in0_shape[0],
                out_shape[0],
            )
        elif rank == 2:
            # 2D case: one program per (BLOCK_M, BLOCK_N) output tile.
            grid = lambda META: (
                triton.cdiv(out_shape[0], META["BLOCK_M"]),
                triton.cdiv(out_shape[1], META["BLOCK_N"]),
            )
            tile_kernel_2d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                out_strides[0],
                out_strides[1],
                in0_shape[0],
                in0_shape[1],
                out_shape[0],
                out_shape[1],
            )
        elif rank == 3:
            # 3D case: axis 0 = output rows, axis 1 = flattened (n, k) tiles.
            grid = lambda META: (
                out_shape[0],
                triton.cdiv(out_shape[1], META["BLOCK_N"])
                * triton.cdiv(out_shape[2], META["BLOCK_K"]),
            )
            tile_kernel_3d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                out_shape[0],
                out_shape[1],
                out_shape[2],
            )
        elif rank == 4:
            # 4D case: axis 0 = flattened (m, n), axis 1 = flattened (k, l) tiles.
            num_mn = out_shape[0] * out_shape[1]
            grid = lambda META: (
                num_mn,
                triton.cdiv(out_shape[2], META["BLOCK_K"])
                * triton.cdiv(out_shape[3], META["BLOCK_L"]),
            )
            tile_kernel_4d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                inp_strides[3],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                out_strides[3],
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                in0_shape[3],
                out_shape[0],
                out_shape[1],
                out_shape[2],
                out_shape[3],
            )
        else:
            if rank > 5:
                # BUGFIX: tile_kernel_nd_flat unrolls at most 5 dims, and the
                # padding loop below never shrinks larger shapes — previously
                # ranks > 5 passed only the first five shape/stride entries
                # and produced wrong results. Fall back to PyTorch's native
                # repeat, which is equivalent to tile once the shapes and
                # dims have been normalized to the same rank (done above).
                return inp.repeat(dims_shape)

            # 5D case — generic flat-index kernel with a capped grid.
            BLOCK_SIZE = 1024
            grid = (min(65535, triton.cdiv(num_tasks, BLOCK_SIZE)),)

            # Left-pad shapes and strides to exactly 5 entries; padded dims
            # use size 1 and stride 0, which the kernel treats as inert.
            while len(in0_shape) < 5:
                in0_shape = [1] + in0_shape
                out_shape = [1] + out_shape
                inp_strides = [0] + inp_strides
                out_strides = [0] + out_strides

            tile_kernel_nd_flat[grid](
                inp,
                out,
                num_tasks,
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                in0_shape[3],
                in0_shape[4],
                out_shape[0],
                out_shape[1],
                out_shape[2],
                out_shape[3],
                out_shape[4],
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                inp_strides[3],
                inp_strides[4],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                out_strides[3],
                out_strides[4],
                rank=rank,
                BLOCK_SIZE=BLOCK_SIZE,
            )
    return out