# src/flag_gems/runtime/backend/_kunlunxin/ops/scatter_add_.py

import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch
import triton
import triton.language as tl

from flag_gems.utils import dim_compress
from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import restride_dim

logger = logging.getLogger(__name__)


@triton.jit
def scatter_add_kernel_1(
    index_dim_n,
    inp_dim_n,
    out_ptr,
    index_ptr,
    src_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    LOOP: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE * LOOP
    arange = tl.arange(0, BLOCK_SIZE)
    for loop_iter in tl.static_range(LOOP):
        src_index_offsets = block_start + arange
        # recompute the bounds mask every iteration, since block_start advances
        mask = src_index_offsets < n_elements
        src_tensor = tl.load(src_ptr + src_index_offsets, mask=mask, other=0)
        index_tensor = tl.load(index_ptr + src_index_offsets, mask=mask, other=0)
        # map each flat src position to its output row, then offset by the index
        out_offsets = src_index_offsets // index_dim_n * inp_dim_n + index_tensor
        tl.atomic_add(out_ptr + out_offsets, src_tensor, mask=mask, sem="relaxed")
        block_start += BLOCK_SIZE
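
# Offset arithmetic by example (hypothetical sizes, with the scatter dim last
# as scatter_add_1 arranges): if index_dim_n == 2 and inp_dim_n == 4, the src
# element at flat offset 5 lies in row 5 // 2 == 2, so it is atomically added
# into out at flat offset 2 * 4 + index[5].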


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.newline()
    code.writeline("from flag_gems.utils import libentry")
    code.writeline("from flag_gems import runtime")
    code.writeline("import flag_gems")
    code.newline()
    code.newline()
    return code


def generate_scatter_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    code.newline()

    # heuristics consumed by the @triton.heuristics decorator below
    code.writeline("def heur_block(args):")
    with code.indent():
        code.writeline("if(flag_gems.vendor_name in ['metax', 'iluvatar']):")
        with code.indent():
            code.writeline("return 256")
        code.writeline("return 128")
    code.newline()
    code.newline()

    code.writeline("def loop_count(args):")
    with code.indent():
        code.writeline("return 1")
    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("{")
        with code.indent():
            code.writeline('"BLOCK": heur_block,')
            code.writeline('"LOOP": loop_count,')
        code.writeline("}")
    code.writeline(")")
    inp_stride_vars = ",".join(f"'inp_stride_{i}'" for i in range(rank))
    index_stride_vars = ",".join(f"'index_stride_{i}'" for i in range(rank))
    src_stride_vars = ",".join(f"'src_stride_{i}'" for i in range(rank))
    shape_vars = ",".join(f"'shape_{i}'" for i in range(rank))
    code.writeline(
        f"@triton.jit(do_not_specialize=['N','stride_dim','inp_size_dim',"
        f"{inp_stride_vars},{index_stride_vars},{src_stride_vars},{shape_vars}])"
    )

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src_strided,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for index")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape")
            code.writeline("inp_size_dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
            code.writeline("BLOCK: tl.constexpr,")
            code.writeline("LOOP: tl.constexpr,")

    code.writeline("):")

    # kernel body
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)")

        # 1. calculate inp_offsets, idx_offsets and src_offsets
        code.writeline("for loop_iter in tl.static_range(LOOP):")
        with code.indent():
            code.writeline("mask = offsets < N")
            code.writeline("cur_idx = offsets")
            code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            # unravel the flat index dimension by dimension, innermost first
            for i in reversed(range(rank)):
                code.writeline(f"mod = cur_idx % shape_{i}")
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
                code.writeline(f"idx_offsets += mod * index_stride_{i}")
                code.writeline(f"src_offsets += mod * src_stride_{i}")
                if i != 0:
                    code.writeline(f"cur_idx = cur_idx // shape_{i}")

            # 2. use the offsets to scatter
            code.writeline(
                "cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)"
            )
            code.writeline(
                "cur_index = tl.load(index + idx_offsets, mask=mask, other=0)"
            )
            code.writeline("dim_offsets = cur_index * stride_dim")
            code.writeline("inp_offsets += dim_offsets")
            code.newline()
            code.writeline(
                "tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
            )
            code.writeline("offsets += BLOCK")

    code.newline()
    code.newline()
    return code
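
# For orientation, what the generator above emits for rank == 1 looks roughly
# like the following (decorator arguments and per-line formatting abbreviated):
#
#     @libentry()
#     @triton.heuristics({"BLOCK": heur_block, "LOOP": loop_count})
#     @triton.jit(do_not_specialize=['N', 'stride_dim', 'inp_size_dim', ...])
#     def _scatter_add_jit_function(
#         src_strided, index, inp, out,
#         inp_stride_0: int, index_stride_0: int, src_stride_0: int,
#         shape_0: int,
#         inp_size_dim, stride_dim, N,
#         BLOCK: tl.constexpr, LOOP: tl.constexpr,
#     ):
#         pid = tl.program_id(0)
#         offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)
#         for loop_iter in tl.static_range(LOOP):
#             mask = offsets < N
#             cur_idx = offsets
#             inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)
#             idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)
#             src_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)
#             mod = cur_idx % shape_0
#             inp_offsets += mod * inp_stride_0
#             idx_offsets += mod * index_stride_0
#             src_offsets += mod * src_stride_0
#             cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)
#             cur_index = tl.load(index + idx_offsets, mask=mask, other=0)
#             inp_offsets += cur_index * stride_dim
#             tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')
#             offsets += BLOCK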


def parameter_for_wrapper() -> str:
    # src_strided, index, inp, out, dim_size, dim_stride, N
    parameters: List[str] = []

    parameters.append("src_strided")
    parameters.append("index")
    parameters.append("inp")
    parameters.append("out")
    parameters.append("dim_size")
    parameters.append("dim_stride")
    parameters.append("N")

    return ", ".join(parameters)


def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("inp_strides = list(inp.stride())")
        code.writeline("index_strides = index.stride()")
        code.writeline("src_strides = src_strided.stride()")
        code.writeline("index_shapes = list(index.shape)")
        code.writeline("inp_size_dim = dim_size")
        code.writeline("stride_dim = dim_stride")

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]), ')
        code.writeline(")")
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            code.writeline("src_strided, index, inp, out, ")
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                code.writeline("inp_size_dim,")
                code.writeline("stride_dim,")
                code.writeline("N,")

        code.writeline(")")
        code.writeline("return out")

    return code
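
# The matching wrapper emitted for rank == 1 is roughly:
#
#     def _scatter_add_wrapper(src_strided, index, inp, out, dim_size, dim_stride, N):
#         inp_strides = list(inp.stride())
#         index_strides = index.stride()
#         src_strides = src_strided.stride()
#         index_shapes = list(index.shape)
#         inp_size_dim = dim_size
#         stride_dim = dim_stride
#         grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]),)
#         _scatter_add_jit_function[grid](
#             src_strided, index, inp, out,
#             inp_strides[0], index_strides[0], src_strides[0], index_shapes[0],
#             inp_size_dim, stride_dim, N,
#         )
#         return out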


def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # inputs: [src_strided, index, inp, out, dim_size, dim_stride, N]
    shape = inputs[1].shape  # the index tensor defines the iteration space
    rank = len(shape)

    code = generate_imports(code)
    code = generate_scatter_kernel(rank, kernel_name, code)
    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
    return code


class ScatterFunction:
    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_scatter_add_wrapper",
                "_scatter_add_jit_function",
                code,
            )

            file_name = f"scatter_add_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # import the generated module and cache its wrapper for this key
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_scatter_add_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


_scatter_func = ScatterFunction()
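
# Usage note: overloads are keyed by the maximum rank among the tensor
# arguments, formatted as a string. Two successive calls with (hypothetical)
# 2-D tensors therefore generate and compile the module once; every later
# call in this process is a dictionary lookup on key "2".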


def scatter_add_0(inp, dim, index, src):
    logger.debug("GEMS SCATTER_ADD_0")
    # fp16/bf16 inputs accumulate into an fp32 buffer and are copied back below
    dtype_convert = False
    if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16:
        out = inp.to(torch.float32)
        dtype_convert = True
    else:
        out = inp

    src_strided = src.as_strided(index.shape, src.stride())
    inp_restrided = restride_dim(inp, dim, index.shape)
    dim_size = inp.size(dim)
    dim_stride = inp.stride(dim)
    N = index.numel()

    _scatter_func(
        src_strided,
        index,
        inp_restrided,
        out,
        dim_size,
        dim_stride,
        N,
    )
    if dtype_convert:
        return inp.copy_(out.to(src.dtype))
    return out


def clip_tensor_to_shape(b, a):
    # slice b so that every dimension is at most as large as a's
    target_shape = a.shape
    slices = [
        slice(0, min(b.shape[i], target_shape[i])) for i in range(len(target_shape))
    ]
    clipped_b = b[tuple(slices)]
    return clipped_b
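
# Example (hypothetical shapes): if b.shape == (5, 3) and a.shape == (4, 2),
# the slices become [slice(0, 4), slice(0, 2)], so clip_tensor_to_shape(b, a)
# is a (4, 2) view of b.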


def scatter_add_1(x, dim, index, src):
    logger.debug("GEMS SCATTER_ADD_1")
    index_dim_n = index.size(dim)
    inp_dim_n = x.size(dim)
    origin = x
    if dim != x.ndim - 1:
        # dim_compress moves `dim` to the last axis of all three tensors
        x = dim_compress(x, dim)
        src = dim_compress(src, dim)
        index = dim_compress(index, dim)

    all_elem = max(x.numel(), index.numel())
    grid = lambda meta: (triton.cdiv(all_elem, meta["BLOCK_SIZE"] * meta["LOOP"]),)

    dtype_convert = False
    if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
        dtype_convert = True
        x = x.to(torch.float32)

    scatter_add_kernel_1[grid](
        index_dim_n, inp_dim_n, x, index, src, all_elem, BLOCK_SIZE=256, LOOP=1
    )
    if dim != x.ndim - 1:
        order = list(range(x.ndim - 1))
        order.insert(dim, x.ndim - 1)
        if dtype_convert:
            return origin.copy_(x.to(src.dtype).permute(order))
        return x.permute(order)
    else:
        if dtype_convert:
            # copy the fp32 accumulation back so the update lands in the
            # original tensor, matching the in-place contract
            return origin.copy_(x.to(src.dtype))
        return x
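
# Inverse-permutation example: with ndim == 3 and dim == 1, dim_compress moved
# axis 1 to the end, and the code above rebuilds order == [0, 2, 1] so the
# result is permuted back to the caller's layout.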


def scatter_add_(x, dim, index, src):
    assert x.dim() == index.dim() and x.dim() == src.dim(), "Invalid rank"
    dim = dim % x.ndim
    assert dim >= 0 and dim < x.dim(), "Invalid dim"
    assert index.size(dim) <= src.size(dim), "Invalid src"
    # count dims where index fully covers x; the scatter dim counts when index
    # spans at least the whole input along it
    equal_count = 0
    for d in range(x.dim()):
        if d != dim:
            assert index.size(d) <= x.size(d), "Invalid x"
            if index.size(d) == x.size(d):
                equal_count += 1
        else:
            if index.size(dim) >= x.size(dim):
                equal_count += 1

    if equal_count == x.dim() and index.shape == src.shape and dim == x.ndim - 1:
        return scatter_add_1(x, dim, index, src)
    # shape-specific special case: certain large workloads also take the flat
    # kernel after clipping src down to the index shape
    if (index.shape == src.shape and index.shape == x.shape and dim != x.ndim - 1) or (
        x.shape[0] == 4096 and x.numel() >= 9437184 and dim != x.ndim - 1
    ):
        if index.shape != src.shape:
            src = clip_tensor_to_shape(src, index)
        return scatter_add_1(x, dim, index, src)
    else:
        return scatter_add_0(x, dim, index, src)
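
# Minimal usage sketch (hypothetical shapes; semantics follow
# torch.Tensor.scatter_add_):
#
#     x = torch.zeros(4, 8, device="cuda")
#     src = torch.ones(4, 8, device="cuda")
#     index = torch.randint(0, 8, (4, 8), device="cuda")
#     scatter_add_(x, 1, index, src)  # x[i][index[i][j]] += src[i][j]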