Coverage for src/flag_gems/ops/gather.py: 18%
130 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
import importlib
import importlib.util
import logging
import os
from typing import Any, Callable, Mapping, Tuple

import torch

from flag_gems.ops.scatter import scatter_
from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
from flag_gems.utils.shape_utils import restride_dim
13logger = logging.getLogger(__name__)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Emit the import header shared by every generated gather module."""
    # triton-side imports
    for line in (
        "import torch",
        "import triton",
        "import triton.language as tl",
    ):
        code.writeline(line)
    code.newline()
    # flag_gems-side imports
    for line in (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils import triton_lang_extension as tle",
    ):
        code.writeline(line)

    code.newline()
    code.newline()
    return code
def generate_gather_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit a rank-specialized Triton gather kernel named *kernel_name*.

    The generated kernel flattens the index tensor, unravels the flat
    offset into per-dimension coordinates, loads the gathered element, and
    stores it into ``out`` elementwise.
    """
    # make the inlined function visible in the context
    code.newline()

    # decorators and signature
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics({'BLOCK_SIZE_N': lambda args: 512})")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        params = ["inp, ", "index, ", "out, "]
        params += [f"inp_shape{d}," for d in range(rank)]
        for prefix in ("index_shape", "out_shape", "inp_stride", "index_stride", "out_stride"):
            params += [f"{prefix}{d}, " for d in range(rank)]
        params += ["dim, ", "dim_stride, ", "N, ", "BLOCK_SIZE_N: tl.constexpr, "]
        code.writelines(params)
    code.writeline("):")

    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline(
            "offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)"
        )
        code.newline()
        # unravel the flat offset into index coordinates, innermost dim first
        code.writeline("cur_offset = offset")
        for d in reversed(range(rank)):
            code.writeline(f"index_idx{d} = cur_offset % index_shape{d}")
            code.writeline(f"cur_offset = cur_offset // index_shape{d}")
        code.newline()
        terms = " + ".join(f"index_idx{d} * index_stride{d}" for d in range(rank))
        code.writeline(f"index_offset = {terms}")
        code.writeline("mask = offset < N")
        code.writeline("cur_index = tl.load(index + index_offset, mask=mask, other=0)")
        code.newline()
        # base input offset; the stride along `dim` is applied separately below
        terms = " + ".join(f"index_idx{d} * inp_stride{d}" for d in range(rank))
        code.writeline(f"inp_offset = {terms}")
        code.writeline("inp_offset += cur_index * dim_stride")
        code.writeline("cur_inp = tl.load(inp + inp_offset, mask=mask, other=0)")
        code.newline()
        terms = " + ".join(f"index_idx{d} * out_stride{d}" for d in range(rank))
        code.writeline(f"out_offset = {terms}")
        code.writeline("tl.store(out + out_offset, value=cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code
def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the host-side wrapper that launches *kernel_name* over a 1D grid."""
    code.writeline(f"def {wrapper_name}(inp, dim, index, out, dim_stride, N):")
    with code.indent():
        # collect the shape/stride metadata the kernel is parameterized on
        for tensor in ("inp", "index", "out"):
            code.writeline(f"{tensor}_shape = {tensor}.shape")
            code.writeline(f"{tensor}_stride = {tensor}.stride()")
        code.writeline("grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE_N']), )")
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            call_args = ["inp, ", "index, ", "out, "]
            for name in (
                "inp_shape",
                "index_shape",
                "out_shape",
                "inp_stride",
                "index_stride",
                "out_stride",
            ):
                call_args += [f"{name}[{d}], " for d in range(rank)]
            call_args += ["dim, ", "dim_stride, ", "N, "]
            code.writelines(call_args)
        code.writeline(")")
        code.writeline("return out")
    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble a full gather module (imports, kernel, wrapper) into *code*.

    Codegen is specialized on the rank of the first input tensor.
    """
    rank = inputs[0].ndim
    return generate_gather_wrapper(
        rank,
        wrapper_name,
        kernel_name,
        generate_gather_kernel(rank, kernel_name, generate_imports(code)),
    )
class GatherFunction:
    """JIT factory for rank-specialized gather wrappers.

    On first call for a given input rank, generates a Triton gather module,
    writes it to the on-disk code cache, imports it, and caches the wrapper
    callable in memory; later calls with the same rank dispatch directly.
    """

    def __init__(self):
        # NOTE(review): pid is recorded but never read in this class —
        # presumably intended for fork-aware cache handling; confirm.
        self.pid = os.getpid()
        # In-memory cache: rank key -> generated wrapper callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the wrapper for the rank of ``args[0]``, building it if needed."""
        key = f"{self.arg_key(*args)}"
        # Single lookup instead of `in` check followed by indexing.
        overload = self.overloads.get(key)
        if overload is None:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_gather_wrapper",
                "_gather_flaggems_jit_function",
                code,
            )

            file_name = f"gather_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load the freshly written module and cache its wrapper
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )
            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            # Direct attribute access; getattr with a constant literal (B009)
            # obscured the name without adding anything.
            overload = m._gather_wrapper
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        """Cache key: the rank of the first positional argument."""
        return args[0].ndim
# Module-level singleton: shared per-rank wrapper cache for all gather calls.
_gather_func = GatherFunction()
def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values from *inp* along *dim* at the positions in *index*.

    Mirrors ``torch.gather``; *sparse_grad* is accepted for API
    compatibility and does not change the forward computation here.
    """
    logger.debug("GEMS GATHER")
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    stride_along_dim = inp.stride(dim)
    total = index.numel()
    # The generated kernel adds cur_index * stride_along_dim on top of the
    # restrided input view itself.
    inp_view = restride_dim(inp, dim, index.shape)
    _gather_func(inp_view, dim, index, out, stride_along_dim, total)
    return out
def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of gather: scatter-add *grad* into an input-shaped zero tensor."""
    logger.debug("GEMS GATHER BACKWARD")
    grad_input = grad.new_zeros(self.shape)
    return scatter_(grad_input, dim, index, grad, reduce="add")