Coverage for src/flag_gems/runtime/backend/_cambricon/ops/gather.py: 0%
157 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import restride_dim

from .scatter import scatter
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
17def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
18 code.writeline("import torch")
19 code.writeline("import triton")
20 code.writeline("import triton.language as tl")
21 code.newline()
22 code.writeline("from flag_gems.utils import libentry, libtuner")
23 code.writeline("from flag_gems import runtime")
24 code.writeline("from flag_gems.utils import triton_lang_extension as tle")
26 code.newline()
27 code.newline()
28 return code
31def generate_gather_kernel(
32 dim: int,
33 large_input: bool,
34 rank: int,
35 kernel_name: str,
36 code: IndentedBuffer,
37) -> IndentedBuffer:
38 # make the inlined function visible in the context
39 code.newline()
41 # the decorators
42 code.writeline("@libentry()")
43 code.writeline(
44 '@libtuner(configs=runtime.get_tuned_config("gather"), key=["N"], strategy=["log"])'
45 )
46 code.writeline("@triton.jit")
48 # signature
49 code.writeline(f"def {kernel_name}(")
50 with code.indent():
51 if rank > 0:
52 code.writeline("inp,")
53 code.writeline("out,")
54 code.writeline("index,")
56 stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
57 code.writeline(f"{stride_args}, # stride for inp")
59 stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
60 code.writeline(f"{stride_args}, # stride for index")
62 shape_args = ", ".join(f"index_shape_{i}: int" for i in range(rank))
63 code.writeline(f"{shape_args}, # shape for index")
65 code.writeline("dim,")
66 code.writeline("stride_dim,")
67 code.writeline("N,")
68 code.writeline("BLOCK_SIZE: tl.constexpr,")
69 code.writeline("):")
71 # Kernel Code
72 with code.indent():
73 code.writeline("pid = tl.program_id(0)")
74 code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
75 code.writeline("mask = offsets < N")
77 # 1. Calculate inp_offsets and idx_offsets
78 if large_input:
79 code.writeline("inp_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)")
80 else:
81 code.writeline("inp_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)")
82 code.writeline("index_offsets = offsets")
84 # 2. snippets
85 for i in range(rank - 1, -1, -1):
86 if not (dim == 0 and i == 0):
87 code.writeline(f"mod = offsets % index_shape_{i}")
89 if i != dim:
90 # will be corrected by adding cur_index*stride_dim
91 code.writeline(f"inp_offsets += mod * inp_stride_{i}")
92 if i != 0:
93 code.writeline(f"offsets //= index_shape_{i}")
95 # Use offsets to gather
96 if large_input:
97 code.writeline(
98 "cur_index = tl.load(index + index_offsets, mask=mask, other=0)"
99 )
100 else:
101 code.writeline(
102 "cur_index = tl.load(index + index_offsets, mask=mask, other=0).to(tl.int32)"
103 )
105 code.writeline("inp_offsets += cur_index * stride_dim")
107 code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
108 code.writeline("tl.store(out + index_offsets, cur_inp, mask=mask)")
110 code.newline()
111 code.newline()
112 return code
115def parameter_for_wrapper() -> str:
116 # inp_strided, out, index, dim, stride_dim, N
117 parameters: List[str] = []
119 parameters.append("inp_strided")
120 parameters.append("out")
121 parameters.append("index")
122 parameters.append("dim")
123 parameters.append("stride_dim")
124 parameters.append("N")
126 return ", ".join(parameters)
129def generate_gather_wrapper(
130 rank: int,
131 wrapper_name: str,
132 kernel_name: str,
133 code: IndentedBuffer,
134) -> IndentedBuffer:
135 parameters: str = parameter_for_wrapper()
136 wrapper_signature: str = f"def {wrapper_name}({parameters}):"
137 code.writeline(wrapper_signature)
139 with code.indent():
140 code.writeline("inp_strides = inp_strided.stride()")
141 code.writeline("index_strides = index.stride()")
142 code.writeline("index_shapes = list(index.shape)")
144 # kernel launch
145 code.writeline("grid = lambda meta: (")
146 with code.indent():
147 code.writeline('triton.cdiv(N, meta["BLOCK_SIZE"]),')
148 code.writeline(")")
150 kernel_launch: str = f"{kernel_name}[grid]("
151 code.writeline(kernel_launch)
153 with code.indent():
154 code.writeline("inp_strided, out, index, ")
155 if rank > 0:
156 s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
157 code.writeline(f"{s},")
159 s = ", ".join(f"index_strides[{i}]" for i in range(rank))
160 code.writeline(f"{s},")
162 s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
163 code.writeline(f"{s},")
165 code.writeline("dim,")
166 code.writeline("stride_dim,")
167 code.writeline("N,")
168 code.writeline(")")
169 code.writeline("return out")
171 return code
174def generate_code(
175 dim: int,
176 large_input: bool,
177 inputs: Tuple[Any],
178 wrapper_name: str,
179 kernel_name: str,
180 code: IndentedBuffer,
181) -> IndentedBuffer:
182 # inputs: inp_strided, out, index, dim, stride_dim, N, large_input
183 shape = inputs[2].shape
184 rank = len(shape)
186 code = generate_imports(code)
187 code = generate_gather_kernel(dim, large_input, rank, kernel_name, code)
188 code = generate_gather_wrapper(rank, wrapper_name, kernel_name, code)
189 return code
192class GatherFunction:
193 def __init__(self):
194 self.pid = os.getpid()
195 self.overloads: Mapping[str, Callable] = {}
197 def __call__(self, *args, **kwargs):
198 rank = kwargs["rank"]
199 dim = kwargs["dim"]
200 large_input = kwargs["large_input"]
202 key = f"{self.arg_key(*args)}_{rank}_{dim}_{large_input}"
203 if key in self.overloads:
204 overload = self.overloads[key]
205 else:
206 code = IndentedBuffer()
207 code = generate_code(
208 dim,
209 large_input,
210 args,
211 "_gather_wrapper",
212 "_gather_jit_function",
213 code,
214 )
216 file_name = f"gather_rank_{key}_pid_{self.pid}.py"
218 with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
219 f.write(code.getvalue())
221 # load
222 spec = importlib.util.spec_from_file_location(
223 f"_gen_module_rank_{key}_pid_{self.pid}",
224 f.name,
225 )
227 m = importlib.util.module_from_spec(spec)
228 spec.loader.exec_module(m)
229 overload = getattr(m, "_gather_wrapper")
230 self.overloads[key] = overload
232 return overload(*args)
234 def arg_key(self, *args):
235 tensors = [item for item in args if torch.is_tensor(item)]
236 max_rank = max(item.ndim for item in tensors)
237 return max_rank
240_gather_func = GatherFunction()
243def gather(inp, dim, index, out=None, sparse_grad=False):
244 logger.debug("GEMS_CAMBRICON GATHER")
245 inp = inp.contiguous()
246 index = index.contiguous()
247 if out is None:
248 out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
249 out = out.contiguous()
250 stride_dim = inp.stride(dim)
252 inp_strided = restride_dim(inp, dim, index.shape)
253 N = index.numel()
255 large_input = inp.numel() * inp.element_size() > 2**31
256 rank = len(index.shape)
258 # <rank>_<dim>_<large_input> is the key of overloads
259 # large_input is only for key
260 _gather_func(
261 inp_strided,
262 out,
263 index,
264 dim,
265 stride_dim,
266 N,
267 large_input=large_input,
268 dim=dim,
269 rank=rank,
270 )
271 return out
274def gather_backward(grad, self, dim, index, sparse_grad):
275 logger.debug("GEMS_CAMBRICON GATHER BACKWARD")
276 result = grad.new_zeros(self.shape)
277 return scatter(result, dim, index, grad, reduce="add")