# src/flag_gems/runtime/backend/_cambricon/ops/scatter.py
import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.newline()
    code.writeline("from flag_gems.utils import libentry, libtuner")
    code.writeline("from flag_gems import runtime")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.newline()
    code.newline()
    return code


def generate_scatter_kernel(
    rank: int,
    dim: int,
    large_tensor: bool,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # make the inlined function visible in the context
    code.newline()

    # the autotune function
    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline(
        '@libtuner(configs=runtime.get_tuned_config("scatter"), '
        'key=["N"], strategy=["log"], restore_value=["out"])'
    )
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"index_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for index")

        code.writeline("dim,")
        code.writeline("stride_dim,")
        code.writeline("N,")
        # reduce options
        code.writeline("IS_ADD: tl.constexpr,")
        code.writeline("IS_MUL: tl.constexpr,")
        code.writeline("BLOCK_SIZE: tl.constexpr,")
    code.writeline("):")

    # kernel body
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # 1. calculate inp_offsets and src_offsets
        if large_tensor:
            code.writeline("inp_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)")
            code.writeline("src_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)")
        else:
            code.writeline("inp_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)")
            code.writeline("src_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)")

        code.writeline("cur_idx = offsets")

        # 2. unravel the flat offset into per-dimension indices, innermost dim
        #    first; along `dim` the inp contribution comes from the index values
        for i in range(rank - 1, -1, -1):
            code.writeline(f"mod = cur_idx % index_shape_{i}")
            if dim != i:
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
            code.writeline(f"src_offsets += mod * src_stride_{i}")
            # the last "//" should be optimized out
            code.writeline(f"cur_idx = cur_idx // index_shape_{i}")

        # 3. use the offsets to scatter
        code.writeline("cur_src = tl.load(src + src_offsets, mask=mask, other=0)")
        if large_tensor:
            code.writeline("cur_index = tl.load(index + offsets, mask=mask, other=0)")
        else:
            code.writeline(
                "cur_index = tl.load(index + offsets, mask=mask, other=0).to(tl.int32)"
            )
        code.writeline("inp_offsets += cur_index * stride_dim")

        code.newline()
        code.writeline("if IS_ADD:")
        with code.indent():
            code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
            code.writeline("res = cur_inp + cur_src")
            code.writeline("tl.store(out + inp_offsets, res, mask=mask)")
        code.writeline("elif IS_MUL:")
        with code.indent():
            code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
            code.writeline("res = cur_inp * cur_src")
            code.writeline("tl.store(out + inp_offsets, res, mask=mask)")
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(out + inp_offsets, cur_src, mask=mask)")

    code.newline()
    code.newline()
    return code
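

# For reference, a sketch of what the generator above emits for rank=2, dim=0,
# large_tensor=False, with the default kernel name used by ScatterFunction.
# Spacing and line breaks are approximate, not the emitted text verbatim:
#
#   @libentry()
#   @libtuner(configs=runtime.get_tuned_config("scatter"), key=["N"],
#             strategy=["log"], restore_value=["out"])
#   @triton.jit
#   def _scatter_jit_function(
#       src, index, inp, out,
#       inp_stride_0: int, inp_stride_1: int,  # stride for inp
#       src_stride_0: int, src_stride_1: int,  # stride for src
#       index_shape_0: int, index_shape_1: int,  # shape for index
#       dim, stride_dim, N,
#       IS_ADD: tl.constexpr, IS_MUL: tl.constexpr, BLOCK_SIZE: tl.constexpr,
#   ):
#       pid = tl.program_id(0)
#       offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#       mask = offsets < N
#       inp_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
#       src_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
#       cur_idx = offsets
#       mod = cur_idx % index_shape_1
#       inp_offsets += mod * inp_stride_1
#       src_offsets += mod * src_stride_1
#       cur_idx = cur_idx // index_shape_1
#       mod = cur_idx % index_shape_0
#       src_offsets += mod * src_stride_0   # i == dim: no inp_stride_0 term
#       cur_idx = cur_idx // index_shape_0
#       cur_src = tl.load(src + src_offsets, mask=mask, other=0)
#       cur_index = tl.load(index + offsets, mask=mask, other=0).to(tl.int32)
#       inp_offsets += cur_index * stride_dim
#       if IS_ADD:
#           cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)
#           res = cur_inp + cur_src
#           tl.store(out + inp_offsets, res, mask=mask)
#       elif IS_MUL:
#           cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)
#           res = cur_inp * cur_src
#           tl.store(out + inp_offsets, res, mask=mask)
#       else:
#           tl.store(out + inp_offsets, cur_src, mask=mask)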


def parameter_for_wrapper() -> str:
    # src, index, inp, out, dim, reduce, N
    parameters: List[str] = []

    parameters.append("src")
    parameters.append("index")
    parameters.append("inp")
    parameters.append("out")
    parameters.append("dim")
    parameters.append("reduce")
    parameters.append("N")

    return ", ".join(parameters)


def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("inp_strides = list(inp.stride())")
        code.writeline("src_strides = src.stride()")
        code.writeline("index_shapes = list(index.shape)")
        code.writeline("stride_dim = inp_strides[dim]")
        # zero the stride along dim: that contribution comes from the index values
        code.writeline("inp_strides[dim] = 0")

        code.writeline('IS_ADD = reduce == "add"')
        code.writeline('IS_MUL = reduce == "multiply"')

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK_SIZE"]),')
        code.writeline(")")

        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            code.writeline("src, index, inp, out,")
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")
            code.writeline("dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
            # reduce options
            code.writeline("IS_ADD,")
            code.writeline("IS_MUL,")
        code.writeline(")")
        code.writeline("return out")

    return code
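

# For reference, the wrapper the generator above emits for rank=2 looks
# roughly like the following (approximate formatting, default names from
# ScatterFunction):
#
#   def _scatter_wrapper(src, index, inp, out, dim, reduce, N):
#       inp_strides = list(inp.stride())
#       src_strides = src.stride()
#       index_shapes = list(index.shape)
#       stride_dim = inp_strides[dim]
#       inp_strides[dim] = 0
#       IS_ADD = reduce == "add"
#       IS_MUL = reduce == "multiply"
#       grid = lambda meta: (
#           triton.cdiv(N, meta["BLOCK_SIZE"]),
#       )
#       _scatter_jit_function[grid](
#           src, index, inp, out,
#           inp_strides[0], inp_strides[1],
#           src_strides[0], src_strides[1],
#           index_shapes[0], index_shapes[1],
#           dim, stride_dim, N,
#           IS_ADD, IS_MUL,
#       )
#       return out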


def generate_code(
    rank: int,
    dim: int,
    large_input: bool,
    inputs: Tuple[Any, ...],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # inputs: [src, index, inp, out, dim, reduce, N]
    # rank is recomputed from the index tensor, which defines the iteration space
    shape = inputs[1].shape
    rank = len(shape)

    code = generate_imports(code)
    code = generate_scatter_kernel(rank, dim, large_input, kernel_name, code)
    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
    return code


class ScatterFunction:
    def __init__(self):
        self.pid = os.getpid()
        # cache of generated wrappers, keyed by "<max_rank>_<rank>_<dim>_<large_tensor>"
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        rank = kwargs["rank"]
        dim = kwargs["dim"]
        large_tensor = kwargs["large_tensor"]

        key = f"{self.arg_key(*args)}_{rank}_{dim}_{large_tensor}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                rank,
                dim,
                large_tensor,
                args,
                "_scatter_wrapper",
                "_scatter_jit_function",
                code,
            )

            file_name = f"scatter_rank_{key}_pid_{self.pid}.py"
            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # import the generated file as a module and cache its wrapper
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )
            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_scatter_wrapper")
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


_scatter_func = ScatterFunction()
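# Cache-key sketch (illustrative values, not produced verbatim by the code):
# for 2-D tensors scattered along dim 0 with 32-bit offsets, arg_key returns 2,
# so the overload key is "2_2_0_False" and the generated file is named roughly
# "scatter_rank_2_2_0_False_pid_<pid>.py" under code_cache_dir().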


def scatter(inp, dim, index, src, reduce=None):
    logger.debug("GEMS_CAMBRICON SCATTER")
    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()
    out = inp.clone()

    N = index.numel()

    # switch to 64-bit offsets once src or out exceeds 2**31 bytes
    large_tensor = (src.numel() * src.element_size() > 2**31) or (
        out.numel() * out.element_size() > 2**31
    )

    # dim is passed twice on purpose: positionally for the generated wrapper's
    # signature, and as a keyword so __call__ can build the overload cache key
    # (<rank>_<dim>_<large_tensor> is part of that key)
    _scatter_func(
        src,
        index,
        inp,
        out,
        dim,
        reduce,
        N,
        rank=len(index.shape),
        large_tensor=large_tensor,
        dim=dim,
    )
    return out
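

# Usage sketch, assuming a Cambricon device registered with torch under the
# "mlu" device string (shapes are illustrative):
#
#   inp = torch.zeros(4, 8, device="mlu")
#   index = torch.randint(0, 4, (2, 8), device="mlu")
#   src = torch.randn(2, 8, device="mlu")
#   out = scatter(inp, 0, index, src)                # plain overwrite
#   out = scatter(inp, 0, index, src, reduce="add")  # out[index[i][j]][j] += src[i][j]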


def scatter_(inp, dim, index, src, reduce=None):
    logger.debug("GEMS_CAMBRICON SCATTER_")
    # note: contiguous() copies when inp is non-contiguous; the update then
    # lands in the copy, which is what gets returned
    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()
    out = inp

    N = index.numel()

    large_tensor = (src.numel() * src.element_size() > 2**31) or (
        out.numel() * out.element_size() > 2**31
    )

    _scatter_func(
        src,
        index,
        inp,
        out,
        dim,
        reduce,
        N,
        rank=len(index.shape),
        large_tensor=large_tensor,
        dim=dim,
    )

    return inp
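

# In-place variant, same assumptions as the sketch above; the (possibly copied,
# see the note on contiguous()) input is updated and returned:
#
#   inp = scatter_(inp, 0, index, src, reduce="multiply")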