Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/scatter_add_.py: 0%
255 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
7import triton
8import triton.language as tl
10from flag_gems.utils import dim_compress
11from flag_gems.utils.code_cache import code_cache_dir
12from flag_gems.utils.code_utils import IndentedBuffer
13from flag_gems.utils.shape_utils import restride_dim
15logger = logging.getLogger(__name__)
@triton.jit
def scatter_add_kernel_1(
    index_dim_n,
    inp_dim_n,
    out_ptr,
    index_ptr,
    src_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    LOOP: tl.constexpr,
):
    """Scatter-add along the (compressed) last dimension.

    Each flat element i of ``src``/``index`` is mapped to row
    ``i // index_dim_n`` of ``out`` (rows have ``inp_dim_n`` elements) at the
    column given by ``index``; values are accumulated with relaxed atomics.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE * LOOP
    arange = tl.arange(0, BLOCK_SIZE)
    for loop_iter in tl.static_range(LOOP):
        src_index_offsets = block_start + arange
        # Recompute the mask every iteration: block_start advances by
        # BLOCK_SIZE per loop, so a mask computed once before the loop
        # would be stale (and wrong) whenever LOOP > 1.
        mask = src_index_offsets < n_elements
        src_tensor = tl.load(src_ptr + src_index_offsets, mask=mask, other=0)
        index_tensor = tl.load(index_ptr + src_index_offsets, mask=mask, other=0)
        out_offsets = src_index_offsets // index_dim_n * inp_dim_n + index_tensor
        tl.atomic_add(out_ptr + out_offsets, src_tensor, mask=mask, sem="relaxed")
        block_start += BLOCK_SIZE
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble for a generated scatter-add module.

    Returns the same buffer so calls can be chained.
    """
    preamble = (
        "import torch",
        "import triton",
        "import triton.language as tl",
    )
    for stmt in preamble:
        code.writeline(stmt)
    code.newline()

    flag_gems_imports = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "import flag_gems",
    )
    for stmt in flag_gems_imports:
        code.writeline(stmt)
    code.newline()
    code.newline()
    return code
def generate_scatter_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the source of a rank-specialized scatter-add Triton kernel.

    The emitted kernel decomposes each flat element id into ``rank``
    per-dimension coordinates, builds strided offsets into the input, index
    and source tensors, then atomically adds the gathered source values into
    ``out`` at the position selected by the index tensor along the scatter
    dimension.

    Args:
        rank: Number of dimensions of the index tensor; fixes how many
            stride/shape parameters the generated kernel receives.
        kernel_name: Name of the generated @triton.jit function.
        code: Buffer the source lines are written into.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    # make the inlined function visible in the context
    code.newline()

    # the autotune function
    code.writeline("def heur_block(args):")
    with code.indent():
        code.writeline("if(flag_gems.vendor_name in ['metax', 'iluvatar']):")
        with code.indent():
            code.writeline("return 256")
        code.writeline("return 128")
    code.newline()
    code.newline()

    code.writeline("def loop_count(args):")
    with code.indent():
        code.writeline("return 1")
    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("{")
        with code.indent():
            code.writeline('"BLOCK": heur_block,')
            code.writeline('"LOOP": loop_count,')
        code.writeline("}")
        code.writeline(")")
    inp_stride_vars = ",".join(f"'inp_stride_{i}'" for i in range(rank))
    index_stride_vars = ",".join(f"'index_stride_{i}'" for i in range(rank))
    src_stride_vars = ",".join(f"'src_stride_{i}'" for i in range(rank))
    shape_vars = ",".join(f"'shape_{i}'" for i in range(rank))
    # Keep runtime-varying scalars unspecialized so one compiled kernel is
    # reused across differently-sized inputs of the same rank.
    code.writeline(
        f"@triton.jit(do_not_specialize=['N','stride_dim','inp_size_dim',"
        f"{inp_stride_vars},{index_stride_vars},{src_stride_vars},{shape_vars}])"
    )

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src_strided,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for index")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape")
        code.writeline("inp_size_dim,")
        code.writeline("stride_dim,")
        code.writeline("N,")
        code.writeline("BLOCK: tl.constexpr,")
        code.writeline("LOOP: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)")

        # 1. Calculate inp_offsets and idx_offsets
        code.writeline("for loop_iter in tl.static_range(LOOP):")
        with code.indent():
            code.writeline("mask = offsets < N")
            code.writeline("cur_idx = offsets")
            code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            # Unravel the flat id innermost-dimension-first, accumulating the
            # per-tensor strided offsets as we go.
            for i in range(rank)[::-1]:
                code.writeline(f"mod = cur_idx % shape_{i}")
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
                code.writeline(f"idx_offsets += mod * index_stride_{i}")
                code.writeline(f"src_offsets += mod * src_stride_{i}")
                if i != 0:
                    code.writeline(f"cur_idx = cur_idx // shape_{i}")

            # 2. Use offsets to scatter
            code.writeline(
                "cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)"
            )
            code.writeline(
                "cur_index = tl.load(index + idx_offsets, mask=mask, other=0)"
            )
            code.writeline("dim_offsets = cur_index * stride_dim")
            code.writeline("inp_offsets += dim_offsets")
            code.newline()
            code.writeline(
                "tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
            )
            code.writeline("offsets += BLOCK")

    code.newline()
    code.newline()
    return code
def parameter_for_wrapper() -> str:
    """Return the comma-joined parameter list of the generated wrapper.

    The order must match both the signature emitted by
    generate_destination_passing_wrapper and the positional call in
    scatter_add_0: src_strided, index, inp, out, dim_size, dim_stride, N.
    """
    names = ("src_strided", "index", "inp", "out", "dim_size", "dim_stride", "N")
    return ", ".join(names)
def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append a Python wrapper that launches the generated scatter kernel.

    The wrapper extracts strides/shapes from its tensor arguments, computes a
    1-D grid from N and the BLOCK/LOOP heuristics, and returns ``out``.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")

    with code.indent():
        # Pull out the metadata the kernel needs from the tensor arguments.
        for stmt in (
            "inp_strides = list(inp.stride())",
            "index_strides = index.stride()",
            "src_strides = src_strided.stride()",
            "index_shapes = list(index.shape)",
            "inp_size_dim = dim_size",
            "stride_dim = dim_stride",
        ):
            code.writeline(stmt)

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]), ')
            code.writeline(")")
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline("src_strided, index, inp, out, ")
            if rank > 0:
                # One comma-terminated line per metadata array, expanded per dim.
                for array in ("inp_strides", "index_strides", "src_strides", "index_shapes"):
                    expanded = ", ".join(f"{array}[{i}]" for i in range(rank))
                    code.writeline(f"{expanded},")
            code.writeline("inp_size_dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
        code.writeline(")")
        code.writeline("return out")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate the full module source (imports, kernel, wrapper).

    ``inputs`` follows the wrapper calling convention
    [src_strided, index, inp, out, dim_size, dim_stride, N]; the index tensor
    (inputs[1]) determines the rank the kernel is specialized for.
    """
    rank = len(inputs[1].shape)
    code = generate_imports(code)
    code = generate_scatter_kernel(rank, kernel_name, code)
    return generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
class ScatterFunction:
    """JIT-compiles and caches rank-specialized scatter-add wrappers.

    On first use for a given tensor rank, the module source is generated,
    written into the code cache directory, imported, and its wrapper cached;
    later calls with the same rank reuse the cached overload.
    """

    def __init__(self):
        # PID keeps cache file names unique across concurrent processes.
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def arg_key(self, *args):
        """Cache key: the largest rank among the tensor arguments."""
        return max(t.ndim for t in args if torch.is_tensor(t))

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            buffer = IndentedBuffer()
            buffer = generate_code(
                args,
                "_scatter_add_wrapper",
                "_scatter_add_jit_function",
                buffer,
            )

            file_name = f"scatter_add_rank_{key}_pid_{self.pid}.py"
            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(buffer.getvalue())

            # Import the freshly written module and cache its wrapper.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            overload = module._scatter_add_wrapper
            self.overloads[key] = overload
        return overload(*args, **kwargs)


_scatter_func = ScatterFunction()
def scatter_add_0(inp, dim, index, src):
    """General scatter-add path via the rank-specialized generated kernel.

    Accumulates in fp32 when ``inp`` is fp16/bf16 (atomic adds on the
    low-precision types would otherwise be used), then copies back in place.
    """
    logger.debug("GEMS SCATTER_ADD_0")
    needs_cast = inp.dtype in (torch.float16, torch.bfloat16)
    out = inp.to(torch.float32) if needs_cast else inp

    # View src with index's shape so both are iterated identically.
    src_strided = src.as_strided(index.shape, src.stride())
    inp_restrided = restride_dim(inp, dim, index.shape)

    _scatter_func(
        src_strided,
        index,
        inp_restrided,
        out,
        inp.size(dim),
        inp.stride(dim),
        index.numel(),
    )
    if needs_cast:
        # NOTE(review): cast target is src.dtype — presumably src and inp
        # share a dtype on this path; confirm with callers.
        return inp.copy_(out.to(src.dtype))
    return out
def clip_tensor_to_shape(b, a):
    """Slice ``b`` so that, per dimension of ``a``, it is no larger than ``a``.

    Only the first ``a.ndim`` dimensions of ``b`` are clipped; each is cut to
    ``min(b.size(d), a.size(d))`` starting at index 0.
    """
    window = tuple(
        slice(0, min(b.size(d), extent)) for d, extent in enumerate(a.shape)
    )
    return b[window]
def scatter_add_1(x, dim, index, src):
    """Fast scatter-add path using the flat last-dimension kernel.

    When ``dim`` is not the last dimension, all three tensors are
    dim-compressed so the scatter dimension becomes contiguous/last, and the
    result is permuted back afterwards. fp16/bf16 inputs are accumulated in
    fp32 and copied back into the original tensor in place.
    """
    logger.debug("GEMS SCATTER_ADD_1")
    index_dim_n = index.size(dim)
    inp_dim_n = x.size(dim)
    origin = x
    # dim_compress permutes without changing ndim, so one check suffices.
    compressed = dim != x.ndim - 1
    if compressed:
        x = dim_compress(x, dim)
        src = dim_compress(src, dim)
        index = dim_compress(index, dim)

    all_elem = max(x.numel(), index.numel())
    grid = lambda meta: (triton.cdiv(all_elem, meta["BLOCK_SIZE"] * meta["LOOP"]),)

    dtype_convert = x.dtype in (torch.float16, torch.bfloat16)
    if dtype_convert:
        x = x.to(torch.float32)

    scatter_add_kernel_1[grid](
        index_dim_n, inp_dim_n, x, index, src, all_elem, BLOCK_SIZE=256, LOOP=1
    )
    if compressed:
        # Undo dim_compress: move the last axis back to position ``dim``.
        order = list(range(x.ndim - 1))
        order.insert(dim, x.ndim - 1)
        if dtype_convert:
            return origin.copy_(x.to(src.dtype).permute(order))
        return x.permute(order)
    if dtype_convert:
        # BUG FIX: previously the fp32 result was returned as a fresh tensor
        # without being written back into the input, breaking the in-place
        # contract that the compressed branch honors via origin.copy_.
        return origin.copy_(x.to(src.dtype))
    return x.to(src.dtype)
def scatter_add_(x, dim, index, src):
    """In-place scatter-add dispatcher.

    Validates shapes, then routes to the fast flat kernel (scatter_add_1)
    when the index fully covers x along every dimension and the scatter dim
    is last, or for large non-last-dim cases matching a tuned heuristic;
    otherwise falls back to the general generated-kernel path.
    """
    assert x.dim() == index.dim() and x.dim() == src.dim(), "Invalid dim"
    dim = dim % x.ndim
    assert dim >= 0 and dim < x.dim(), "Invalid dim"
    assert index.size(dim) <= src.size(dim), "Invalid src"

    # Count dims where index spans the whole of x.
    equal_count = 0
    for d in range(x.dim()):
        if d == dim:
            if index.size(dim) >= x.size(dim):
                equal_count += 1
        else:
            assert index.size(d) <= x.size(d), "Invalid x"
            if index.size(d) == x.size(d):
                equal_count += 1

    last = x.ndim - 1
    if equal_count == x.dim() and index.shape == src.shape and dim == last:
        return scatter_add_1(x, dim, index, src)

    shapes_all_match = index.shape == src.shape and index.shape == x.shape
    # Empirically tuned heuristic for large non-last-dim scatters.
    big_case = x.shape[0] == 4096 and x.numel() >= 9437184
    if dim != last and (shapes_all_match or big_case):
        if index.shape != src.shape:
            src = clip_tensor_to_shape(src, index)
        return scatter_add_1(x, dim, index, src)
    return scatter_add_0(x, dim, index, src)