Coverage for src/flag_gems/runtime/backend/_ascend/ops/index_add.py: 0%
145 statements
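
"""Runtime code generation for index_add on the Ascend backend.

Builds a rank-specialized Triton kernel plus a launch wrapper as Python
source, writes it to the code cache, imports it, and memoizes the resulting
callable per tensor rank.
"""
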
import importlib.util
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer

logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
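    """Emit the import preamble of the generated module."""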
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.writeline("from flag_gems.utils import libentry")
    code.writeline("from flag_gems import runtime")
    code.newline()
    code.newline()
    return code


def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
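    """Emit a Triton kernel, specialized for `rank`, that atomically adds
    `src * alpha` into `out` at the positions selected by `index` along dim."""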
    # the decorators
    code.writeline("@libentry()")
    code.writeline(
        '@triton.autotune(configs=runtime.get_tuned_config("index_add"), key=["BLOCK_SIZE"])'
    )
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")
            code.writeline("inp_shape_dim,")
            code.writeline("src_shape_dim,")
            code.writeline("delta,")
            code.writeline("alpha,")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args},  # stride for src")

            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args},  # shape for src")

        code.writeline("BLOCK_SIZE: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

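        # Unravel each flat element id into per-dimension coordinates of src,
        # then re-linearize with src's strides so the generated kernel also
        # handles non-contiguous source layouts.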
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        comp = [f"src_offset{i} * src_stride_{i}" for i in range(rank)]
        code.writeline(f"src_offset = {' + '.join(comp)}")

        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")
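
        # pre_cal spans one index-dim block of src in flat-offset terms. The
        # generated code splits src_offset into the block id (pre_idx) and the
        # position along dim (dim_idx), looks up the target row
        # src_dim_idx = index[dim_idx], and rebuilds the flat offset into out,
        # whose dim is longer by `delta` rows.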
        # index add
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = input_idx < inp_numel")
        code.writeline(
            "add_on = tl.load(src + src_offset, mask=mask, other=0) * alpha"
        )
        code.writeline(
            "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
        )
        # TODO: tl.atomic_add doesn't support bfloat16! The non-atomic
        # load/store fallback below may be unsafe under concurrent updates.
        # code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
        # code.writeline("tl.store(out + input_idx, cur_out + add_on, mask=input_mask)")

    code.newline()
    code.newline()
    return code


def parameter_for_wrapper() -> str:
    # out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp_numel, alpha
    parameters: List[str] = [
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
        "alpha",
    ]
    return ", ".join(parameters)


def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
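    """Emit the Python wrapper that computes the grid and launches the kernel."""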
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("src_strides = list(src.stride())")
        code.writeline("src_shapes = list(src.shape)")

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE']),")
        code.writeline(")")
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha,"
            )
            if rank > 0:
                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")
            # code.writeline("BLOCK_SIZE=BLOCK_SIZE")
        code.writeline(")")
        code.writeline("return out")

    return code


def generate_code(
    inputs: Tuple[Any, ...],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
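    """Assemble imports, kernel, and wrapper for the rank of src (inputs[2])."""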
    # inputs: [out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp_numel, alpha]
    shape = inputs[2].shape  # src
    rank = len(shape)

    code = generate_imports(code)
    code = generate_index_add_kernel(rank, kernel_name, code)
    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
    return code


class IndexAddFunction:
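    """Caches one generated index_add overload per source-tensor rank."""
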
    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
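            # Cache miss: generate the module source, persist it to the code
            # cache on disk, import it, and memoize the wrapper callable.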
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_index_add_wrapper",
                "_index_add_jit_function",
                code,
            )

            file_name = f"index_add_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )
            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_add_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


_index_add_func = IndexAddFunction()


def index_add(inp, dim, index, src, alpha=1):
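    """Return a copy of `inp` with `alpha * src` added at `index` along `dim`,
    mirroring the semantics of torch.Tensor.index_add_."""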
    logger.debug("GEMS_ASCEND INDEX ADD")
    assert torch.all(
        (0 <= index) & (index < inp.size(dim))
    ), "0 <= index < self.size(dim)"
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dim-th dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert all(
        inp.size(i) == src.size(i) or i == dim % inp.ndim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    out = inp.clone()

    dim %= inp.ndim
    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = inp.size(dim)
    delta = inp.size(dim) - src_shape_dim
    N = src.numel()

    _index_add_func(
        out,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        inp.numel(),
        alpha,
    )
    return out
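

# A minimal usage sketch (hypothetical shapes; assumes the Ascend device is
# exposed to torch as "npu" -- adjust the device string for your setup):
#
#   inp = torch.zeros(3, 4, device="npu")
#   src = torch.ones(2, 4, device="npu")
#   index = torch.tensor([0, 2], device="npu")
#   out = index_add(inp, 0, index, src, alpha=2.0)
#   # rows 0 and 2 of out are 2.0; all other rows remain 0.0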