Coverage for src/flag_gems/ops/index_add.py: 100%
157 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer
# Module-level logger, named after this module per stdlib logging convention.
logger = logging.getLogger(__name__)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble needed by the generated kernel module."""
    preamble = (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    )
    for line in preamble:
        code.writeline(line)

    code.newline()
    code.newline()

    return code
def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the source of a rank-specialized Triton kernel for index_add.

    The generated kernel atomically accumulates ``alpha * src`` into ``out``
    at positions selected by ``index`` along one dimension.  The per-dimension
    strides and shape of ``src`` are passed as ``rank`` scalar arguments, so a
    separate kernel is generated (and cached) per tensor rank.
    """
    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")  # total number of src elements to process
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")  # stride of inp along the indexed dim
            code.writeline("inp_shape_dim,")  # size of inp along the indexed dim
            code.writeline("src_shape_dim,")  # size of src along the indexed dim
            code.writeline("delta,")  # inp.size(dim) - src.size(dim)
            code.writeline("alpha,")

            # One scalar stride / shape argument per src dimension.
            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

        code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # Decompose the linear element id into per-dimension src coordinates
        # (innermost dimension first), then rebuild the strided src offset.
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        comp = [f"src_offset{i} * src_stride_{i}" for i in range(rank)]
        code.writeline(f"src_offset = {' + '.join(comp)}")

        # Elements per "slab" of inp preceding the indexed dimension.
        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")

        # index add
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        # Generated runtime guard: index values must be within inp's dim size.
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        # Map the src offset to the destination offset in inp/out by replacing
        # the src coordinate along dim with the indexed coordinate.
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = input_idx < inp_numel")
        code.writeline(
            "add_on = tl.load(src + src_offset, mask=mask, other=0) * alpha"
        )
        # Atomic accumulation: duplicate index values must all land.
        code.writeline(
            "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
        )
        # TODO: tl.atomic_add doesn't support bfloat16! The following method may be unsafe.
        # code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
        # code.writeline("tl.store(out + input_idx, cur_out + add_on, mask=input_mask)")

    code.newline()
    code.newline()
    return code
def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper.

    Order: out, index, src, dim, inp_stride_dim, inp_shape_dim,
    src_shape_dim, delta, N, inp_numel, alpha.
    """
    names = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
        "alpha",
    )
    return ", ".join(names)
def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python wrapper that launches the generated kernel.

    The wrapper receives the precomputed scalars (see parameter_for_wrapper),
    extracts src's strides/shape, and launches the kernel over a 1-D grid of
    ceil(N / BLOCK_SIZE) programs, writing into the preallocated ``out``.
    """
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name} ({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("src_strides = list(src.stride())")
        code.writeline("src_shapes = list(src.shape)")

        # kernel launch
        code.writeline("BLOCK_SIZE = 128")  # BLOCK_SIZE setting
        code.writeline("grid = (triton.cdiv(N, BLOCK_SIZE),)")
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            # Argument order must match generate_index_add_kernel's signature.
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            if rank > 0:
                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")
        code.writeline(")")
        code.writeline("return out")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the full generated module: imports, kernel, then wrapper.

    The kernel is specialized on the rank of ``src`` (``inputs[2]``).
    """
    # inputs: [out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp.numel(), alpha]
    rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_index_add_kernel(rank, kernel_name, code)
    return generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
class IndexAddFunction:
    """Callable that dispatches index_add to a rank-specialized generated kernel.

    On first call for a given tensor rank, it generates the kernel + wrapper
    source, writes it to the code cache directory, imports it as a module, and
    caches the wrapper; later calls with the same rank reuse the cached wrapper.
    """

    def __init__(self):
        # PID is baked into cache file/module names so concurrent processes
        # don't clobber each other's generated files.
        self.pid = os.getpid()
        # NOTE(review): annotated as Mapping but mutated in __call__ —
        # effectively a dict cache keyed by the stringified max tensor rank.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # Cache key: the maximum rank among tensor arguments (see arg_key).
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_index_add_wrapper",
                "_index_add_jit_function",
                code,
            )

            file_name = f"index_add_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_add_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        # Specialization key: highest ndim among the tensor arguments.
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank
# Process-wide dispatcher; caches one generated overload per tensor rank.
_index_add_func = IndexAddFunction()
def index_add(inp, dim, index, src, alpha=1):
    """Out-of-place index_add.

    Returns a copy of ``inp`` in which, along dimension ``dim``, the slices
    selected by ``index`` have ``alpha * src`` accumulated into them
    (same contract as ``torch.Tensor.index_add``).

    Args:
        inp: destination tensor (not modified; a clone is returned).
        dim: dimension to index along; may be negative.
        index: integer tensor of indices into inp along ``dim``.
        src: values to add; must have the same ndim as inp.
        alpha: scalar multiplier applied to src (default 1).
    """
    logger.debug("GEMS INDEX ADD")
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"

    # Normalize a negative dim once, before comparing shapes dimension-wise.
    dim %= inp.ndim

    # BUGFIX: the original asserted a bare generator expression, which is
    # always truthy and never checked anything; all() performs the real check.
    assert all(
        inp.size(i) == src.size(i) or i == dim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    out = inp.clone()

    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = inp.size(dim)
    delta = inp.size(dim) - src_shape_dim  # extent of inp not covered by src
    N = src.numel()

    _index_add_func(
        out,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        inp.numel(),
        alpha,
    )
    return out
def index_add_(inp, dim, index, src, alpha=1):
    """In-place index_add.

    Along dimension ``dim``, accumulates ``alpha * src`` into the slices of
    ``inp`` selected by ``index`` and returns ``inp`` (same contract as
    ``torch.Tensor.index_add_``).

    Args:
        inp: destination tensor, modified in place and returned.
        dim: dimension to index along; may be negative.
        index: integer tensor of indices into inp along ``dim``.
        src: values to add; must have the same ndim as inp.
        alpha: scalar multiplier applied to src (default 1).
    """
    logger.debug("GEMS INDEX ADD_")
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"

    # Normalize a negative dim once, before comparing shapes dimension-wise.
    dim %= inp.ndim

    # BUGFIX: the original asserted a bare generator expression, which is
    # always truthy and never checked anything; all() performs the real check.
    assert all(
        inp.size(i) == src.size(i) or i == dim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = inp.size(dim)
    delta = inp.size(dim) - src_shape_dim  # extent of inp not covered by src
    N = src.numel()

    _index_add_func(
        inp,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        inp.numel(),
        alpha,
    )
    return inp