Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/scatter.py: 0% (239 statements)
coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
from flag_gems.utils.shape_utils import (
    MemOverlap,
    has_internal_overlapping,
    restride_dim,
)

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.newline()
    code.writeline("from flag_gems.utils import libentry")
    code.writeline("from flag_gems import runtime")
    code.writeline("import flag_gems")
    # code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.newline()
    code.newline()
    return code
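

# For reference, the header that generate_imports emits into every generated
# module is just:
#
#   import torch
#   import triton
#   import triton.language as tl
#
#   from flag_gems.utils import libentry
#   from flag_gems import runtime
#   import flag_gems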


def generate_scatter_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # make the inlined function visible in the context
    code.newline()

    # the autotune function
    code.writeline("def heur_block(args):")
    with code.indent():
        code.writeline(
            'return triton.next_power_of_2(triton.cdiv(triton.cdiv(args["N"], 12), 4))'
        )  # LOOP = 4
    code.newline()
    code.newline()

    code.writeline("def loop_count(args):")
    with code.indent():
        code.writeline("return 4")
    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("{")
        with code.indent():
            code.writeline('"BLOCK": heur_block,')
            code.writeline('"LOOP": loop_count,')
        code.writeline("}")
    code.writeline(")")
    inp_stride_vars = ",".join(f"'inp_stride_{i}'" for i in range(rank))
    index_stride_vars = ",".join(f"'index_stride_{i}'" for i in range(rank))
    src_stride_vars = ",".join(f"'src_stride_{i}'" for i in range(rank))
    shape_vars = ",".join(f"'shape_{i}'" for i in range(rank))
    code.writeline(
        f"@triton.jit(do_not_specialize=['N','stride_dim','inp_size_dim',"
        f"{inp_stride_vars},{index_stride_vars},{src_stride_vars},{shape_vars}])"
    )

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src_strided,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for index")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape")

        code.writeline("inp_size_dim,")
        code.writeline("stride_dim,")
        code.writeline("N,")
        # reduce options
        code.writeline("IS_ADD: tl.constexpr,")
        code.writeline("IS_MUL: tl.constexpr,")
        code.writeline("BLOCK: tl.constexpr,")
        code.writeline("LOOP: tl.constexpr,")
        code.writeline("INT32_OFFSET: tl.constexpr")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("if not INT32_OFFSET:")
        with code.indent():
            code.writeline("pid = pid.to(tl.int64)")
        code.writeline("offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)")

        # 1. Calculate inp_offsets and idx_offsets
        code.writeline("for loop_iter in tl.static_range(LOOP):")
        with code.indent():
            code.writeline("mask = offsets < N")
            code.writeline("cur_idx = offsets")
            code.writeline("if INT32_OFFSET:")
            with code.indent():
                code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
                code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
                code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("else:")
            with code.indent():
                code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
                code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
                code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
            # unravel the linear offset into strided offsets, innermost dim first
            for i in reversed(range(rank)):
                code.writeline("if INT32_OFFSET:")
                with code.indent():
                    code.writeline(f"shape_{i} = shape_{i}.to(tl.int32)")
                    code.writeline(f"inp_stride_{i} = inp_stride_{i}.to(tl.int32)")
                    code.writeline(f"index_stride_{i} = index_stride_{i}.to(tl.int32)")
                    code.writeline(f"src_stride_{i} = src_stride_{i}.to(tl.int32)")
                code.writeline(f"mod = cur_idx % shape_{i}")
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
                code.writeline(f"idx_offsets += mod * index_stride_{i}")
                code.writeline(f"src_offsets += mod * src_stride_{i}")
                if i != 0:
                    code.writeline(f"cur_idx = cur_idx // shape_{i}")

            # 2. Use offsets to scatter
            code.writeline(
                "cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)"
            )
            code.writeline(
                "cur_index = tl.load(index + idx_offsets, mask=mask, other=0)"
            )
            code.writeline("if INT32_OFFSET:")
            with code.indent():
                code.writeline("cur_index = cur_index.to(tl.int32)")
                code.writeline("stride_dim = stride_dim.to(tl.int32)")

            code.writeline("dim_offsets = cur_index * stride_dim")
            code.writeline("inp_offsets += dim_offsets")
            code.newline()
            code.writeline("if IS_ADD: ")
            with code.indent():
                code.writeline(
                    "tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
                )
            code.writeline("elif IS_MUL: ")
            with code.indent():
                code.writeline(
                    "tl.atomic_mul(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
                )

            code.writeline("else: ")
            with code.indent():
                code.writeline("tl.store(out + inp_offsets, cur_src, mask=mask)")

            code.writeline("offsets += BLOCK")

    code.newline()
    code.newline()
    return code
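

# A condensed sketch of the kernel the generator above emits for rank == 1,
# assembled by hand from the writelines (the INT32_OFFSET narrowing and the
# atomic add/mul branches are elided, so this is a sketch, not verbatim output):
#
#   @libentry()
#   @triton.heuristics({"BLOCK": heur_block, "LOOP": loop_count})
#   @triton.jit(do_not_specialize=['N', 'stride_dim', 'inp_size_dim',
#                                  'inp_stride_0', 'index_stride_0',
#                                  'src_stride_0', 'shape_0'])
#   def _scatter_jit_function(src_strided, index, inp, out,
#                             inp_stride_0: int, index_stride_0: int,
#                             src_stride_0: int, shape_0: int,
#                             inp_size_dim, stride_dim, N,
#                             IS_ADD: tl.constexpr, IS_MUL: tl.constexpr,
#                             BLOCK: tl.constexpr, LOOP: tl.constexpr,
#                             INT32_OFFSET: tl.constexpr):
#       pid = tl.program_id(0)
#       offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)
#       for loop_iter in tl.static_range(LOOP):
#           mask = offsets < N
#           # decompose the linear offset into per-tensor strided offsets
#           mod = offsets % shape_0
#           inp_offsets = mod * inp_stride_0
#           idx_offsets = mod * index_stride_0
#           src_offsets = mod * src_stride_0
#           cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)
#           cur_index = tl.load(index + idx_offsets, mask=mask, other=0)
#           # redirect the scatter dimension through the loaded index value
#           inp_offsets += cur_index * stride_dim
#           tl.store(out + inp_offsets, cur_src, mask=mask)  # or tl.atomic_add / tl.atomic_mul
#           offsets += BLOCK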


def parameter_for_wrapper() -> str:
    # src_strided, index, inp, out, dim_size, dim_stride, N, reduce, int32_offset
    parameters: List[str] = []

    parameters.append("src_strided")
    parameters.append("index")
    parameters.append("inp")
    parameters.append("out")
    parameters.append("dim_size")
    parameters.append("dim_stride")
    parameters.append("N")
    parameters.append("reduce: tl.constexpr=None")
    parameters.append("int32_offset: tl.constexpr=None")

    return ", ".join(parameters)
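

# The join above yields the generated wrapper's parameter list verbatim:
#
#   src_strided, index, inp, out, dim_size, dim_stride, N,
#   reduce: tl.constexpr=None, int32_offset: tl.constexpr=None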


def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("inp_strides = list(inp.stride())")
        code.writeline("index_strides = index.stride()")
        code.writeline("src_strides = src_strided.stride()")
        code.writeline("index_shapes = list(index.shape)")
        code.writeline("inp_size_dim = dim_size")
        code.writeline("stride_dim = dim_stride")

        code.writeline('IS_ADD = reduce == "add"')
        code.writeline('IS_MUL = reduce == "multiply"')
        # default to int32 offsets only when the caller left it unspecified;
        # the previous `int32_offset or True` silently turned an explicit
        # False back into True
        code.writeline("int32_offset = True if int32_offset is None else int32_offset")

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]), ')
        code.writeline(")")

        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)

        with code.indent():
            code.writeline("src_strided, index, inp, out, ")
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")
            code.writeline("inp_size_dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
            # reduce options
            code.writeline("IS_ADD,")
            code.writeline("IS_MUL,")
            code.writeline("INT32_OFFSET=int32_offset,")
            # code.writeline("buffer_size_limit=512,")
            # code.writeline("isCloseUnrollControl=True,")
        code.writeline(")")
        code.writeline("return out")

    return code
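

# A condensed sketch of the wrapper emitted for rank == 1 (hand-assembled
# from the writelines above, with the pass-through aliases inlined; not
# verbatim output):
#
#   def _scatter_wrapper(src_strided, index, inp, out, dim_size, dim_stride,
#                        N, reduce=None, int32_offset=None):
#       inp_strides = list(inp.stride())
#       index_strides = index.stride()
#       src_strides = src_strided.stride()
#       index_shapes = list(index.shape)
#       IS_ADD = reduce == "add"
#       IS_MUL = reduce == "multiply"
#       int32_offset = True if int32_offset is None else int32_offset
#       grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]),)
#       _scatter_jit_function[grid](
#           src_strided, index, inp, out,
#           inp_strides[0], index_strides[0], src_strides[0], index_shapes[0],
#           dim_size, dim_stride, N, IS_ADD, IS_MUL, INT32_OFFSET=int32_offset,
#       )
#       return out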


def generate_code(
    inputs: Tuple[Any, ...],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # inputs: [src_strided, index, inp, out, dim_size, dim_stride, N, reduce]
    # the rank of the index tensor decides which specialization to generate
    shape = inputs[1].shape
    rank = len(shape)

    code = generate_imports(code)
    code = generate_scatter_kernel(rank, kernel_name, code)
    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
    return code


class ScatterFunction:
    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_scatter_wrapper",
                "_scatter_jit_function",
                code,
            )

            file_name = f"scatter_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_scatter_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank
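

# Example of the per-rank specialization in ScatterFunction.__call__
# (hypothetical tensors; "..." stands for the remaining scalar arguments):
#
#   func = ScatterFunction()
#   func(src_2d, idx_2d, inp_2d, out_2d, ...)  # key "2": codegen, cache, run
#   func(src_2d, idx_2d, inp_2d, out_2d, ...)  # key "2": cache hit
#   func(src_3d, idx_3d, inp_3d, out_3d, ...)  # key "3": codegen again for rank 3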


_scatter_func = ScatterFunction()


def scatter(inp, dim, index, src, reduce=None):
    logger.debug("GEMS SCATTER")
    out = inp.clone()

    if reduce is not None:
        assert inp.dtype not in (
            torch.bfloat16,
        ), "Unsupported operation: reduce scatter on bfloat16 tensors."

    if has_internal_overlapping(out) == MemOverlap.Yes:
        out = out.contiguous()

    src_strided = src.as_strided(index.shape, src.stride())
    inp_restrided = restride_dim(inp, dim, index.shape)
    dim_size = inp.size(dim)
    dim_stride = inp.stride(dim)
    N = index.numel()

    # use int32 offsets only when every tensor's extent along dim fits in 2**32
    int32_size_dim = lambda x: x.stride(dim) * x.size(dim) < 2**32
    use_int32_offset = all(map(int32_size_dim, (inp, index, src)))
    _scatter_func(
        src_strided,
        index,
        inp_restrided,
        out,
        dim_size,
        dim_stride,
        N,
        reduce,
        int32_offset=use_int32_offset,
    )

    return out
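

# Hedged usage sketch (semantics mirror torch.Tensor.scatter; the tensors and
# the "cuda" device below are illustrative, not part of this module):
#
#   inp = torch.zeros(3, 5, device="cuda")
#   index = torch.tensor([[0, 1, 2]], device="cuda")
#   src = torch.ones(1, 5, device="cuda")
#   out = scatter(inp, 0, index, src)                # plain scatter
#   out = scatter(inp, 0, index, src, reduce="add")  # atomic-add reduction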


def scatter_(inp, dim, index, src, reduce=None):
    logger.debug("GEMS SCATTER_")
    out = inp

    if reduce is not None:
        assert inp.dtype not in (
            torch.bfloat16,
        ), "Unsupported operation: reduce scatter on bfloat16 tensors."

    assert (
        has_internal_overlapping(out) != MemOverlap.Yes
    ), "Unsupported operation: in-place write to an internally overlapping tensor."

    src_restrided = src.as_strided(index.shape, src.stride())
    inp_restrided = restride_dim(inp, dim, index.shape)
    dim_size = inp.size(dim)
    dim_stride = inp.stride(dim)
    N = index.numel()

    int32_size_dim = lambda x: x.stride(dim) * x.size(dim) < 2**32
    use_int32_offset = all(map(int32_size_dim, (inp, index, src)))
    _scatter_func(
        src_restrided,
        index,
        inp_restrided,
        out,
        dim_size,
        dim_stride,
        N,
        reduce,
        int32_offset=use_int32_offset,
    )

    return inp
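

# In-place sketch, following the same pattern (illustrative tensors):
#
#   scatter_(inp, 0, index, src)  # writes through inp's storage and returns inp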