import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Tuple

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import restride_dim

from .scatter import scatter_

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
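    """Emit the import preamble shared by every generated gather module."""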
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.writeline("import builtins")
    code.newline()
    code.writeline("from flag_gems.utils import libentry")
    code.writeline("from flag_gems import runtime")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")

    code.newline()
    code.newline()
    return code


def generate_gather_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
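    """Write a rank-specialized Triton gather kernel into ``code``.

    The per-dimension offset arithmetic is unrolled ``rank`` times at code
    generation time, so each rank gets its own fully specialized kernel.
    """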
    # make the inlined function visible in the context
    code.newline()

    # the autotune function
    code.writeline("def cfggen():")
    with code.indent():
        code.writeline("block_m = [1, 2, 4, 8]")
        code.writeline("block_n = [256, 512, 1024, 2048]")
        code.writeline("configs = [")
        with code.indent():
            code.writeline('triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)')
            code.writeline("for m in block_m")
            code.writeline("for n in block_n")
        code.writeline("]")
        code.writeline("return configs")

    code.newline()
    code.newline()

    code.writeline("def heur_block_m(args):")
    with code.indent():
        code.writeline('return triton.next_power_of_2(triton.cdiv(args["M"], 12))')

    code.newline()

    code.writeline("def heur_block_n(args):")
    with code.indent():
        code.writeline('return builtins.min(triton.next_power_of_2(args["N"]), 4096)')

    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    # code.writeline('@triton.autotune(configs=cfggen(), key=["M", "N"])')
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_M": heur_block_m,')
            code.writeline('"BLOCK_N": heur_block_n,')
        code.writeline("},")
    code.writeline(")")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("inp,")
            code.writeline("out,")
            code.writeline("index,")

            stride_args = ", ".join(
                f"inp_stride_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(
                f"index_stride_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for index")

            shape_args = ", ".join(
                f"index_shape_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{shape_args}, # shape for index")

        code.writeline("dim: tl.constexpr,")
        code.writeline("stride_dim: tl.constexpr,")
        code.writeline("M: tl.constexpr,")
        code.writeline("N: tl.constexpr,")
        code.writeline("BLOCK_M: tl.constexpr,")
        code.writeline("BLOCK_N: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid_x = tle.program_id(0)")
        code.writeline("pid_y = tle.program_id(1)")
        code.writeline(
            "rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]"
        )
        code.writeline(
            "cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]"
        )
        code.writeline("rows_mask = rows_offsets < M")
        code.writeline("cols_mask = cols_offsets < N")

        code.writeline("offsets = (rows_offsets * N + cols_offsets).to(tl.int64)")
        code.writeline("mask = rows_mask & cols_mask")

        # 1. Calculate inp_offsets and idx_offsets
        code.writeline("inp_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)")
        code.writeline("idx_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)")
        code.writeline("cur_idx = rows_offsets * N + cols_offsets")

        # 2. per-dimension offset decomposition, unrolled over rank
        for i in range(rank):
            code.writeline(f"mod = cur_idx % index_shape_{i}")
            code.writeline(f"inp_offsets += mod * inp_stride_{i}")
            code.writeline(f"idx_offsets += mod * index_stride_{i}")
            if i != (rank - 1):
                code.writeline(f"cur_idx //= index_shape_{i}")

        # Use offsets to gather
        code.writeline("cur_index = tl.load(index + idx_offsets, mask=mask, other=0)")
        code.writeline("inp_offsets += cur_index * stride_dim")
        code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
        code.writeline("tl.store(out + idx_offsets, cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code
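

# For reference, the unrolled offset decomposition for a rank-2 index tensor
# looks roughly like the sketch below (a hand-written approximation, not
# verbatim generator output). Because the caller re-strides the input to
# stride 0 along `dim`, the decomposition contributes nothing on that axis;
# the loaded index value supplies it via `cur_index * stride_dim`:
#
#     cur_idx = rows_offsets * N + cols_offsets
#     mod = cur_idx % index_shape_0
#     inp_offsets += mod * inp_stride_0
#     idx_offsets += mod * index_stride_0
#     cur_idx //= index_shape_0
#     mod = cur_idx % index_shape_1
#     inp_offsets += mod * inp_stride_1
#     idx_offsets += mod * index_stride_1
#
#     cur_index = tl.load(index + idx_offsets, mask=mask, other=0)
#     inp_offsets += cur_index * stride_dim
#     cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)
#     tl.store(out + idx_offsets, cur_inp, mask=mask)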


def parameter_for_wrapper() -> str:
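    """Return the comma-separated parameter list for the generated wrapper."""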
    # inp_strided, out, index, dim, stride_dim, M, N
    parameters: List[str] = []

    parameters.append("inp_strided")
    parameters.append("out")
    parameters.append("index")
    parameters.append("dim")
    parameters.append("stride_dim")
    parameters.append("M")
    parameters.append("N")

    return ", ".join(parameters)


def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
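    """Emit the host-side wrapper that computes the launch grid and calls the kernel."""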
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("inp_strides = inp_strided.stride()")
        code.writeline("index_strides = index.stride()")
        code.writeline("index_shapes = list(index.shape)")

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(M, meta["BLOCK_M"]),')
            code.writeline('triton.cdiv(N, meta["BLOCK_N"])')
        code.writeline(")")

        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)

        with code.indent():
            code.writeline("inp_strided, out, index, ")
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")

            code.writeline("dim,")
            code.writeline("stride_dim,")
            code.writeline("M,")
            code.writeline("N,")
        code.writeline(")")
        code.writeline("return out")

    return code
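

# The generated wrapper for a rank-2 index tensor is approximately the
# following (a sketch, not verbatim output):
#
#     def _gather_wrapper(inp_strided, out, index, dim, stride_dim, M, N):
#         inp_strides = inp_strided.stride()
#         index_strides = index.stride()
#         index_shapes = list(index.shape)
#         grid = lambda meta: (
#             triton.cdiv(M, meta["BLOCK_M"]),
#             triton.cdiv(N, meta["BLOCK_N"])
#         )
#         _gather_jit_function[grid](
#             inp_strided, out, index,
#             inp_strides[0], inp_strides[1],
#             index_strides[0], index_strides[1],
#             index_shapes[0], index_shapes[1],
#             dim, stride_dim, M, N,
#         )
#         return out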


def generate_code(
    inputs: Tuple[Any, ...],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
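    """Assemble the complete module: imports, kernel, and wrapper.

    The specialization rank is taken from ``inputs[2]``, the index tensor.
    """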
    # inputs: inp_strided, out, index, dim, stride_dim, M, N
    shape = inputs[2].shape
    rank = len(shape)

    code = generate_imports(code)
    code = generate_gather_kernel(rank, kernel_name, code)
    code = generate_gather_wrapper(rank, wrapper_name, kernel_name, code)
    return code


class GatherFunction:
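    """Generate, cache, and dispatch rank-specialized gather implementations.

    Each distinct rank gets its own generated module, written to the code
    cache and imported once; subsequent calls reuse the cached overload.
    """
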
    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_gather_wrapper",
                "_gather_jit_function",
                code,
            )

            file_name = f"gather_rank_{key}_pid_{self.pid}.py"

            with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load the generated module and pick up the wrapper
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_gather_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


_gather_func = GatherFunction()


def gather(inp, dim, index, out=None, sparse_grad=False):
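    """Gather values of ``inp`` along ``dim`` at the positions in ``index``.

    Mirrors ``torch.gather``. ``restride_dim`` zeroes the input stride along
    ``dim``, so inside the kernel the gathered coordinate comes entirely from
    the loaded index values. ``sparse_grad`` is accepted for API compatibility
    but not used in the forward pass.
    """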
    logger.debug("GEMS GATHER")
    inp = inp.contiguous()
    index = index.contiguous()
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    out = out.contiguous()
    stride_dim = inp.stride(dim)

    inp_strided = restride_dim(inp, dim, index.shape)
    # plain_idx = torch.arange(0, index.numel(), device=inp.device).reshape(index.shape)
    N = index.shape[-1]
    M = index.numel() // N

    _gather_func(inp_strided, out, index, dim, stride_dim, M, N)
    return out
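

# Minimal usage sketch (the device string here is a stand-in; use whatever
# device this backend is registered for):
#
#     inp = torch.randn(4, 8, device="cuda")
#     idx = torch.randint(0, 8, (4, 3), device="cuda")
#     res = gather(inp, 1, idx)
#     # matches the eager reference:
#     # torch.testing.assert_close(res, torch.gather(inp, 1, idx))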


def gather_backward(grad, self, dim, index, sparse_grad):
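    """Adjoint of gather: scatter-add ``grad`` into a zeros tensor shaped like ``self``."""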
    logger.debug("GEMS GATHER BACKWARD")
    result = grad.new_zeros(self.shape)
    return scatter_(result, dim, index, grad, reduce="add")