# src/flag_gems/ops/scatter_add_.py

import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Tuple

import torch
import triton
import triton.language as tl

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import restride_dim

from ..utils import dim_compress

logger = logging.getLogger(__name__)

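# scatter_add_kernel_1 assumes x/index/src have been laid out so that the
# scatter dim is the last axis in a row-major, contiguous layout. Each flat
# offset then splits into row = offset // index_dim_n, and the loaded index
# value picks the column, so the output slot is row * inp_dim_n + index.
# Illustrative example: with index of shape (2, 3) and out of shape (2, 5),
# flat offset 4 is row 4 // 3 = 1; if index[1, 1] == 3, the update lands at
# out offset 1 * 5 + 3 = 8, i.e. out[1, 3].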

@triton.jit
def scatter_add_kernel_1(
    index_dim_n,
    inp_dim_n,
    out_ptr,
    index_ptr,
    src_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    LOOP: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE * LOOP
    arange = tl.arange(0, BLOCK_SIZE)
    for loop_iter in tl.static_range(LOOP):
        src_index_offsets = block_start + arange
        # the mask must be refreshed each iteration since block_start advances
        mask = src_index_offsets < n_elements
        src_tensor = tl.load(src_ptr + src_index_offsets, mask=mask, other=0)
        index_tensor = tl.load(index_ptr + src_index_offsets, mask=mask, other=0)
        # row of the element times the output row length, plus the scattered column
        out_offsets = src_index_offsets // index_dim_n * inp_dim_n + index_tensor
        tl.atomic_add(out_ptr + out_offsets, src_tensor, mask=mask, sem="relaxed")
        block_start += BLOCK_SIZE


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.newline()
    code.writeline("from flag_gems.utils import libentry")
    code.writeline("from flag_gems import runtime")
    code.writeline("import flag_gems")
    code.newline()
    code.newline()
    return code

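# For reference, the kernel emitted by generate_scatter_kernel() looks roughly
# like the sketch below for a rank-2 index (illustrative and trimmed; the
# exact text is whatever the writelines below produce):
#
#     @libentry()
#     @triton.heuristics({"BLOCK": heur_block, "LOOP": loop_count})
#     @triton.jit(do_not_specialize=['N', 'stride_dim', 'inp_size_dim', ...])
#     def _scatter_add_jit_function(
#         src_strided, index, inp, out,
#         inp_stride_0: int, inp_stride_1: int,  # stride for inp
#         index_stride_0: int, index_stride_1: int,  # stride for index
#         src_stride_0: int, src_stride_1: int,  # stride for src
#         shape_0: int, shape_1: int,  # shape
#         inp_size_dim, stride_dim, N,
#         BLOCK: tl.constexpr, LOOP: tl.constexpr,
#     ):
#         pid = tl.program_id(0)
#         offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)
#         for loop_iter in tl.static_range(LOOP):
#             mask = offsets < N
#             cur_idx = offsets
#             inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)
#             ...  # unravel cur_idx into per-dim offsets, innermost dim first
#             cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)
#             cur_index = tl.load(index + idx_offsets, mask=mask, other=0)
#             inp_offsets += cur_index * stride_dim
#             tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')
#             offsets += BLOCK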

def generate_scatter_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # leading blank line separating the generated kernel from the imports
    code.newline()

    # heuristic functions consumed by @triton.heuristics below
    code.writeline("def heur_block(args):")
    with code.indent():
        code.writeline("if flag_gems.vendor_name in ['metax', 'iluvatar']:")
        with code.indent():
            code.writeline("return 256")
        code.writeline("return 128")
    code.newline()
    code.newline()

    code.writeline("def loop_count(args):")
    with code.indent():
        code.writeline("return 1")
    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("{")
        with code.indent():
            code.writeline('"BLOCK": heur_block,')
            code.writeline('"LOOP": loop_count,')
        code.writeline("}")
    code.writeline(")")
    inp_stride_vars = ",".join(f"'inp_stride_{i}'" for i in range(rank))
    index_stride_vars = ",".join(f"'index_stride_{i}'" for i in range(rank))
    src_stride_vars = ",".join(f"'src_stride_{i}'" for i in range(rank))
    shape_vars = ",".join(f"'shape_{i}'" for i in range(rank))
    code.writeline(
        f"@triton.jit(do_not_specialize=['N','stride_dim','inp_size_dim',"
        f"{inp_stride_vars},{index_stride_vars},{src_stride_vars},{shape_vars}])"
    )

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src_strided,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for index")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape")
        code.writeline("inp_size_dim,")
        code.writeline("stride_dim,")
        code.writeline("N,")
        code.writeline("BLOCK: tl.constexpr,")
        code.writeline("LOOP: tl.constexpr,")
    code.writeline("):")

    # kernel body
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)")

        # 1. unravel the flat offset into inp_offsets, idx_offsets and src_offsets
        code.writeline("for loop_iter in tl.static_range(LOOP):")
        with code.indent():
            code.writeline("mask = offsets < N")
            code.writeline("cur_idx = offsets")
            code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            for i in reversed(range(rank)):
                code.writeline(f"mod = cur_idx % shape_{i}")
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
                code.writeline(f"idx_offsets += mod * index_stride_{i}")
                code.writeline(f"src_offsets += mod * src_stride_{i}")
                if i != 0:
                    code.writeline(f"cur_idx = cur_idx // shape_{i}")

            # 2. use the offsets to scatter
            code.writeline(
                "cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)"
            )
            code.writeline(
                "cur_index = tl.load(index + idx_offsets, mask=mask, other=0)"
            )
            code.writeline("dim_offsets = cur_index * stride_dim")
            code.writeline("inp_offsets += dim_offsets")
            code.newline()
            code.writeline(
                "tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
            )
            code.writeline("offsets += BLOCK")

    code.newline()
    code.newline()
    return code


def parameter_for_wrapper() -> str:
    # src_strided, index, inp, out, dim_size, dim_stride, N
    parameters: List[str] = []

    parameters.append("src_strided")
    parameters.append("index")
    parameters.append("inp")
    parameters.append("out")
    parameters.append("dim_size")
    parameters.append("dim_stride")
    parameters.append("N")

    return ", ".join(parameters)


def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("inp_strides = list(inp.stride())")
        code.writeline("index_strides = index.stride()")
        code.writeline("src_strides = src_strided.stride()")
        code.writeline("index_shapes = list(index.shape)")
        code.writeline("inp_size_dim = dim_size")
        code.writeline("stride_dim = dim_stride")

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]), ')
        code.writeline(")")
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            code.writeline("src_strided, index, inp, out, ")
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")

            code.writeline("inp_size_dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
        code.writeline(")")
        code.writeline("return out")

    return code


def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # inputs: (src_strided, index, inp, out, dim_size, dim_stride, N)
    # the rank of the generated kernel is taken from the index tensor
    shape = inputs[1].shape
    rank = len(shape)

    code = generate_imports(code)
    code = generate_scatter_kernel(rank, kernel_name, code)
    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
    return code

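# ScatterFunction is a small codegen cache: arg_key() keys an overload by the
# maximum rank among the tensor arguments, and __call__ generates, writes,
# imports, and memoizes the rank-specialized wrapper on first use.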

class ScatterFunction:
    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_scatter_add_wrapper",
                "_scatter_add_jit_function",
                code,
            )

            file_name = f"scatter_add_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load the generated module and cache its wrapper
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_scatter_add_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


_scatter_func = ScatterFunction()

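# scatter_add_0 is the general, stride-aware path: src is viewed with index's
# shape, inp is restrided along `dim` (restride_dim) so offsets line up with
# index, and the generated kernel performs the scatter. fp16/bf16 inputs are
# accumulated into an fp32 copy and copied back, presumably for atomic-add
# support and accumulation accuracy.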

def scatter_add_0(inp, dim, index, src):
    logger.debug("GEMS SCATTER_ADD_0")
    dtype_convert = False
    if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16:
        out = inp.to(torch.float32)
        dtype_convert = True
    else:
        out = inp

    src_strided = src.as_strided(index.shape, src.stride())
    inp_restrided = restride_dim(inp, dim, index.shape)
    dim_size = inp.size(dim)
    dim_stride = inp.stride(dim)
    N = index.numel()

    _scatter_func(
        src_strided,
        index,
        inp_restrided,
        out,
        dim_size,
        dim_stride,
        N,
    )
    if dtype_convert:
        return inp.copy_(out.to(src.dtype))
    return out


def clip_tensor_to_shape(b, a):
    # slice b down so that no dimension exceeds a's shape
    target_shape = a.shape
    slices = [
        slice(0, min(b.shape[i], target_shape[i])) for i in range(len(target_shape))
    ]
    clipped_b = b[tuple(slices)]
    return clipped_b

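# Illustrative example: if b has shape (5, 3) and a has shape (4, 2), then
# clip_tensor_to_shape(b, a) returns b[0:4, 0:2], i.e. a view of shape (4, 2).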

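# scatter_add_1 is the fast path for layouts where `dim` can be moved to the
# last axis. dim_compress permutes `dim` to the end (yielding a contiguous
# tensor), so scatter_add_kernel_1 can address rows and columns with plain
# integer division and modulo instead of per-dim strides.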

def scatter_add_1(x, dim, index, src):
    logger.debug("GEMS SCATTER_ADD_1")
    index_dim_n = index.size(dim)
    inp_dim_n = x.size(dim)
    origin = x
    if dim != x.ndim - 1:
        x = dim_compress(x, dim)
        src = dim_compress(src, dim)
        index = dim_compress(index, dim)

    all_elem = max(x.numel(), index.numel())
    grid = lambda meta: (triton.cdiv(all_elem, meta["BLOCK_SIZE"] * meta["LOOP"]),)

    dtype_convert = False
    if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
        dtype_convert = True
        x = x.to(torch.float32)

    scatter_add_kernel_1[grid](
        index_dim_n, inp_dim_n, x, index, src, all_elem, BLOCK_SIZE=256, LOOP=1
    )
    if dim != x.ndim - 1:
        # undo the dim_compress permutation: move the last axis back to `dim`
        order = list(range(x.ndim - 1))
        order.insert(dim, x.ndim - 1)
        if dtype_convert:
            return origin.copy_(x.to(src.dtype).permute(order))
        return x.permute(order)
    if dtype_convert:
        # write the fp32 accumulation back into the original tensor, matching
        # the permuted branch above and scatter_add_0
        return origin.copy_(x.to(src.dtype))
    return x


def scatter_add_(x, dim, index, src):
    assert (
        x.dim() == index.dim() and x.dim() == src.dim()
    ), "x, index and src must have the same number of dimensions"
    dim = dim % x.ndim
    assert dim >= 0 and dim < x.dim(), "Invalid dim"
    assert index.size(dim) <= src.size(dim), "Invalid src"
    # count the dims along which index already spans x; when all of them do
    # (and index matches src), the compressed fast path applies
    equal_count = 0
    for d in range(x.dim()):
        if d != dim:
            assert index.size(d) <= x.size(d), "Invalid x"
            if index.size(d) == x.size(d):
                equal_count += 1
        else:
            if index.size(dim) >= x.size(dim):
                equal_count += 1

    if equal_count == x.dim() and index.shape == src.shape and dim == x.ndim - 1:
        return scatter_add_1(x, dim, index, src)
    # empirical heuristic: also take the fast path for full-shape scatters on a
    # non-last dim, or for large (4096, ...) inputs
    if (index.shape == src.shape and index.shape == x.shape and dim != x.ndim - 1) or (
        x.shape[0] == 4096 and x.numel() >= 9437184 and dim != x.ndim - 1
    ):
        if index.shape != src.shape:
            src = clip_tensor_to_shape(src, index)
        return scatter_add_1(x, dim, index, src)
    return scatter_add_0(x, dim, index, src)
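

# Usage sketch (illustrative; assumes a device flag_gems supports):
#
#     x = torch.zeros(4, 8, device="cuda")
#     index = torch.randint(0, 4, (4, 8), device="cuda")
#     src = torch.ones(4, 8, device="cuda")
#     scatter_add_(x, 0, index, src)
#     # for dim=0 this computes x[index[i][j]][j] += src[i][j]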