Coverage for src/flag_gems/runtime/backend/_mthreads/ops/gather.py: 0%
71 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
2from typing import Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.ops.gather import gather as default_gather
10from flag_gems.ops.gather import gather_backward as default_gather_backward
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry
# Per-module logger, named after this file so backend-specific log lines can
# be filtered (e.g. "flag_gems.runtime.backend._mthreads.ops.gather").
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Dtypes accepted by the custom Triton gather kernel below; any other input
# dtype makes _use_triton_kernel() return False and falls back to the
# generic flag_gems implementation.
_SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float32}
@libentry()
@triton.heuristics(runtime.get_heuristic_config("gather"))
@triton.jit
def _gather_lastdim_kernel(
    inp_ptr,
    index_ptr,
    out_ptr,
    stride_inp_row,
    stride_index_row,
    stride_out_row,
    dim_stride,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Gather along the last dimension of tensors flattened to 2D.

    Each program instance handles one BLOCK_M x BLOCK_N tile of the (M, N)
    index/output grid: it loads idx = index[row, col] and copies the value at
    inp_ptr + row * stride_inp_row + idx * dim_stride into out[row, col].

    All pointers and strides are prepared by _launch_triton; M is the number
    of flattened leading rows, N the last-dimension length of `index`.
    BLOCK_M/BLOCK_N come from the "gather" heuristic config.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    # Promote tile coordinates to int64 so the address arithmetic below
    # cannot overflow 32 bits on large tensors.
    rows = rows.to(tl.int64)
    cols = cols.to(tl.int64)
    mask = (rows < M) & (cols < N)

    # Row base offsets for input, index and output.
    row_inp = rows * stride_inp_row
    row_idx = rows * stride_index_row
    row_out = rows * stride_out_row

    # NOTE(review): index values are not bounds-checked here; masked lanes
    # read `other=0`, but in-bounds lanes trust the caller-provided indices.
    idx = tl.load(index_ptr + row_idx + cols, mask=mask, other=0).to(tl.int64)
    gather_ptr = inp_ptr + row_inp + idx * dim_stride
    values = tl.load(gather_ptr, mask=mask, other=0)
    tl.store(out_ptr + row_out + cols, values, mask=mask)
55def _normalize_dim(dim: int, ndim: int) -> int:
56 return dim if dim >= 0 else dim + ndim
59def _use_triton_kernel(
60 inp: torch.Tensor,
61 dim: int,
62 index: torch.Tensor,
63 out: Optional[torch.Tensor],
64) -> bool:
65 if inp.device.type != "musa" or index.device != inp.device:
66 return False
67 if inp.dtype not in _SUPPORTED_DTYPES or index.dtype != torch.long:
68 return False
70 dim = _normalize_dim(dim, inp.ndim)
71 if dim != inp.ndim - 1:
72 return False
74 if not inp.is_contiguous() or not index.is_contiguous():
75 return False
76 if out is not None:
77 if (
78 out.device != inp.device
79 or out.dtype != inp.dtype
80 or not out.is_contiguous()
81 ):
82 return False
84 if index.shape[:-1] != inp.shape[:-1]:
85 return False
87 return True
def _launch_triton(
    inp: torch.Tensor,
    index: torch.Tensor,
    out: torch.Tensor,
    dim_stride: int,
) -> torch.Tensor:
    """Flatten the operands to 2D and run _gather_lastdim_kernel into `out`.

    `dim_stride` is the input stride along the gathered (last) dimension.
    Returns `out` for caller convenience.
    """
    last = index.shape[-1]
    flat_inp = inp.view(-1, inp.shape[-1])
    flat_idx = index.view(-1, last)
    flat_out = out.view(-1, last)

    num_rows, num_cols = flat_idx.shape

    # 2D grid: one tile of BLOCK_M x BLOCK_N elements per program.
    def grid(meta):
        return (
            triton.cdiv(num_rows, meta["BLOCK_M"]),
            triton.cdiv(num_cols, meta["BLOCK_N"]),
        )

    with torch_device_fn.device(out.device):
        _gather_lastdim_kernel[grid](
            flat_inp,
            flat_idx,
            flat_out,
            flat_inp.stride(0),
            flat_idx.stride(0),
            flat_out.stride(0),
            dim_stride,
            num_rows,
            num_cols,
        )
    return out
def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values of `inp` along `dim` at positions given by `index`.

    Dispatches to the MUSA last-dim Triton kernel when all of its
    preconditions hold; otherwise delegates to the generic flag_gems
    implementation with the original arguments.
    """
    logger.debug("GEMS_MTHREADS GATHER")
    if not _use_triton_kernel(inp, dim, index, out):
        return default_gather(inp, dim, index, out, sparse_grad)

    result = out
    if result is None:
        result = torch.empty_like(index, dtype=inp.dtype, device=inp.device)

    stride = inp.stride(_normalize_dim(dim, inp.ndim))
    return _launch_triton(inp, index, result, stride)
def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward pass for gather; delegates to the generic implementation.

    No custom MUSA kernel exists for the backward direction, so this only
    logs and forwards every argument unchanged.
    """
    logger.debug("GEMS_MTHREADS GATHER BACKWARD")
    result = default_gather_backward(grad, self, dim, index, sparse_grad)
    return result