Coverage for src/flag_gems/runtime/backend/_cambricon/ops/index_add.py: 0%
193 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
7import triton
8import triton.language as tl
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils.code_cache import code_cache_dir
12from flag_gems.utils.code_utils import IndentedBuffer
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def cfggen():
    """Build the autotune search space for index_add_kernel.

    Returns a list of triton.Config objects covering every combination of
    BLOCK_M in {1, 2, 4, 8} and BLOCK_N in {128, 1024, 2048, 4096}, all
    with num_warps=1.
    """
    configs = []
    for block_m in (1, 2, 4, 8):
        for block_n in (128, 1024, 2048, 4096):
            configs.append(
                triton.Config({"BLOCK_M": block_m, "BLOCK_N": block_n}, num_warps=1)
            )
    return configs
@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def index_add_kernel(
    inp,  # pointer to the (dim-compressed) input, viewed as (M, inp_len)
    out,  # pointer to the output (a clone of inp), same layout
    index,  # pointer to the N column indices into inp's last dimension
    src,  # pointer to the source values, viewed as (M, N)
    M,  # number of rows after dim_compress
    N,  # number of indexed columns (== index.numel())
    alpha,  # scalar multiplier applied to src
    inp_len,  # size of the indexed (last) dimension of inp
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program instance handles one BLOCK_M x BLOCK_N tile of the
    # (M, N) src grid.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)
    rows_mask = rows_offsets < M
    index_mask = cols_offsets < N
    # NOTE(review): `and` between mask tensors relies on Triton's elementwise
    # handling; `&` is the conventional spelling — confirm on this backend.
    block_mask = rows_mask and index_mask
    # Gather the destination column for each src column, then read-modify-write.
    cur_indices = tl.load(index + cols_offsets, mask=index_mask, other=0)
    inp_off = rows_offsets * inp_len + cur_indices[None, :]
    cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0)
    src_off = rows_offsets * N + cols_offsets[None, :]
    cur_src = tl.load(src + src_off, mask=block_mask, other=0.0)
    cur_inp += alpha * cur_src
    # NOTE(review): non-atomic store — duplicate values in `index` would lose
    # updates within/across tiles; presumably callers guarantee uniqueness or
    # accept last-writer-wins. Verify against the PyTorch index_add contract.
    tl.store(out + inp_off, cur_inp, mask=block_mask)
def index_add(inp, dim, index, src, alpha=1):
    """Out-of-place index_add: return a copy of `inp` where slices of `src`
    (scaled by `alpha`) are accumulated at the positions given by `index`
    along dimension `dim`.

    Args:
        inp: input tensor.
        dim: dimension to index along; may be negative.
        index: 1-D integer tensor with values in [0, inp.size(dim)).
        src: values to add; same ndim as `inp`, matching sizes on all
            dimensions except `dim`, where src.size(dim) == index.numel().
        alpha: scalar multiplier for `src` (default 1).

    Returns:
        A new tensor of the same shape as `inp`.
    """
    logger.debug("GEMS_CAMBRICON INDEX ADD")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    # Validate index range on the input's own device (was hard-coded to
    # "cuda", inconsistent with index_add_ below).
    assert bool(
        ((0 <= index) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    # The original asserted a bare generator expression, which is always
    # truthy; wrap in all() so the per-dimension shape check actually runs.
    assert all(
        inp.size(i) == src.size(i) or i == dim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    inp_len = inp.size(dim)
    N = index.numel()
    M = src.numel() // N
    fine_dim = inp.ndim - 1
    if dim != fine_dim:
        # Move `dim` to the innermost position so the kernel can treat the
        # problem as a 2-D (M, N) index_add over the last dimension.
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)
    out = inp.clone()

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)
    if dim != fine_dim:
        # Undo dim_compress: move the last dimension back to position `dim`.
        order = [i for i in range(out.ndim - 1)]
        order.insert(dim, fine_dim)
        return out.permute(order).contiguous()
    else:
        return out
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated index_add module into `code`
    and return the same buffer."""
    for line in (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    ):
        code.writeline(line)

    code.newline()
    code.newline()

    return code
def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the source of a Triton JIT kernel named `kernel_name` that performs
    an in-place index_add over a `src` tensor of `rank` dimensions, writing the
    generated text into `code` and returning it.

    The generated kernel iterates flat src offsets, decomposes them into
    per-dimension coordinates, maps each onto the destination via the loaded
    index value, and accumulates with tl.atomic_add.
    """
    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")
            code.writeline("inp_shape_dim,")
            code.writeline("src_shape_dim,")
            code.writeline("delta,")
            code.writeline("alpha,")

            # One stride and one shape scalar per src dimension.
            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

            code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # Decompose the flat offset into per-dimension coordinates
        # (innermost dimension first), then rebuild a strided src offset.
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        comp = [f"src_offset{i} * src_stride_{i}" for i in range(rank)]
        code.writeline(f"src_offset = {' + '.join(comp)}")

        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")

        # index add
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = input_idx < inp_numel")
        code.writeline(
            "add_on = tl.load(src + src_offset, mask=mask, other=0) * alpha"
        )
        # Atomic accumulation tolerates duplicate indices across programs.
        code.writeline(
            "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
        )
        # TODO: tl.atomic_add doesn't support bfloat16! The following method may be unsafe.
        # code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
        # code.writeline("tl.store(out + input_idx, cur_out + add_on, mask=input_mask)")

    code.newline()
    code.newline()
    return code
def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list used in the signature of the
    generated wrapper function."""
    return ", ".join(
        (
            "out",
            "index",
            "src",
            "dim",
            "inp_stride_dim",
            "inp_shape_dim",
            "src_shape_dim",
            "delta",
            "N",
            "inp_numel",
            "alpha",
        )
    )
def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit a Python wrapper named `wrapper_name` that computes src strides and
    shapes, picks the launch grid, invokes `kernel_name`, and returns `out`.
    Writes the generated text into `code` and returns it."""
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name} ({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("src_strides = list(src.stride())")
        code.writeline("src_shapes = list(src.shape)")

        # kernel launch
        code.writeline("BLOCK_SIZE = 640")  # BLOCK_SIZE setting
        code.writeline("grid = (triton.cdiv(N, BLOCK_SIZE),)")
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            # Argument order must match generate_index_add_kernel's signature.
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            if rank > 0:
                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")
        code.writeline(")")
        code.writeline("return out")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the complete generated module — imports, JIT kernel, and launch
    wrapper — into `code` and return it.

    `inputs` follows the call layout of index_add_:
    [out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim,
     delta, N, inp.numel(), alpha]; the src tensor (index 2) determines the
    rank the kernel is specialized for.
    """
    src_rank = len(inputs[2].shape)

    buffer = generate_imports(code)
    buffer = generate_index_add_kernel(src_rank, kernel_name, buffer)
    buffer = generate_destination_passing_wrapper(
        src_rank, wrapper_name, kernel_name, buffer
    )
    return buffer
class IndexAddFunction:
    """Callable cache of code-generated index_add_ overloads.

    On first call for a given key (the maximum rank among tensor arguments),
    generates a specialized kernel + wrapper module, writes it into the code
    cache directory, imports it, and memoizes the wrapper; subsequent calls
    with the same key reuse the cached overload.
    """

    def __init__(self):
        # pid is baked into file/module names so concurrent processes sharing
        # the code cache directory don't clobber each other's generated files.
        self.pid = os.getpid()
        # key (stringified max rank) -> generated wrapper callable
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Generate the specialized module source for this rank.
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_index_add_wrapper",
                "_index_add_jit_function",
                code,
            )

            file_name = f"index_add_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load
            # Import the just-written file as a throwaway module and pull out
            # the generated wrapper.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_add_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        # Cache key: the largest ndim among the tensor arguments.
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank
308_index_add_func = IndexAddFunction()
def index_add_(inp, dim, index, src, alpha=1):
    """In-place index_add: accumulate slices of `src` (scaled by `alpha`) into
    `inp` at the positions given by `index` along dimension `dim`.

    Args:
        inp: tensor modified in place.
        dim: dimension to index along; may be negative.
        index: 1-D integer tensor with values in [0, inp.size(dim)).
        src: values to add; same ndim as `inp`, matching sizes on all
            dimensions except `dim`, where src.size(dim) == index.numel().
        alpha: scalar multiplier for `src` (default 1).

    Returns:
        `inp`, for convenience.
    """
    logger.debug("GEMS_CAMBRICON INDEX ADD_")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim %= inp.ndim
    assert bool(
        ((0 <= index) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    # The original asserted a bare generator expression, which is always
    # truthy; wrap in all() so the per-dimension shape check actually runs.
    assert all(
        inp.size(i) == src.size(i) or i == dim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    # Scalars consumed by the generated kernel (see generate_index_add_kernel):
    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = inp.size(dim)
    delta = inp_shape_dim - src_shape_dim  # size gap between inp and src on dim
    N = src.numel()

    _index_add_func(
        inp,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        inp.numel(),
        alpha,
    )
    return inp