Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/scatter.py: 0%

239 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10from flag_gems.utils.shape_utils import ( 

11 MemOverlap, 

12 has_internal_overlapping, 

13 restride_dim, 

14) 

15 

# Module logger, created as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

17 

18 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated scatter module into ``code``.

    Emits the torch/triton imports, a blank line, the flag_gems helpers the
    generated kernel and wrapper refer to, and finally two blank lines.

    Args:
        code: Buffer receiving the generated source.

    Returns:
        The same ``code`` buffer, to allow chaining.
    """
    third_party = ("import torch", "import triton", "import triton.language as tl")
    for statement in third_party:
        code.writeline(statement)
    code.newline()

    project = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "import flag_gems",
    )
    for statement in project:
        code.writeline(statement)

    # Two blank lines separate the header from the first generated definition.
    code.newline()
    code.newline()
    return code

31 

32 

def generate_scatter_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Triton scatter kernel specialized for an index of rank ``rank``.

    The generated source contains two heuristic helpers (``heur_block`` and
    ``loop_count``) followed by the ``@libentry`` / ``@triton.heuristics`` /
    ``@triton.jit`` decorated kernel named ``kernel_name``.

    Each kernel program walks ``LOOP`` consecutive chunks of ``BLOCK`` linear
    positions of the index tensor. For every position it:

    1. unravels the position over ``shape_0 .. shape_{rank-1}`` (last
       dimension first), accumulating offsets for ``inp``/``out``, ``index``
       and ``src_strided`` from their per-dimension strides;
    2. loads the source value and the index value, then adds
       ``index_value * stride_dim`` to the output offset;
    3. writes into ``out`` with ``tl.atomic_add`` when ``IS_ADD``, with
       ``tl.atomic_mul`` when ``IS_MUL``, and with a plain ``tl.store``
       otherwise.

    ``INT32_OFFSET`` selects int32 or int64 offset arithmetic.

    Args:
        rank: Number of dimensions of the index tensor. One stride argument per
            tensor and one shape argument are generated per dimension.
        kernel_name: Name of the generated kernel function.
        code: Buffer receiving the generated source.

    Returns:
        The same ``code`` buffer.
    """
    # make the inlined function visible in the context
    code.newline()

    # the autotune function
    # BLOCK heuristic: combined with the wrapper's grid of
    # cdiv(N, BLOCK * LOOP) and LOOP == 4, this launches at most 12 programs
    # for N > 0. NOTE(review): the result for N == 0 depends on
    # triton.next_power_of_2(0); callers should skip empty launches.

    code.writeline("def heur_block(args):")
    with code.indent():
        code.writeline(
            'return triton.next_power_of_2(triton.cdiv(triton.cdiv(args["N"], 12), 4))'
        )  # LOOP = 4
    code.newline()
    code.newline()

    # Fixed number of BLOCK-sized chunks processed by each program.
    code.writeline("def loop_count(args):")
    with code.indent():
        code.writeline("return 4")
    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("{")
        with code.indent():
            code.writeline('"BLOCK": heur_block,')
            code.writeline('"LOOP": loop_count,')
        code.writeline("}")
    code.writeline(")")
    # The integer size/stride arguments are excluded from value specialization,
    # presumably to avoid recompiling for every distinct size or stride.
    # NOTE(review): when rank == 0 the joined name lists are empty, so the
    # emitted list contains ",,," and is not valid Python; rank 0 is unsupported.
    inp_stride_vars = ",".join(f"'inp_stride_{i}'" for i in range(rank))
    index_stride_vars = ",".join(f"'index_stride_{i}'" for i in range(rank))
    src_stride_vars = ",".join(f"'src_stride_{i}'" for i in range(rank))
    shape_vars = ",".join(f"'shape_{i}'" for i in range(rank))
    code.writeline(
        f"@triton.jit(do_not_specialize=['N','stride_dim','inp_size_dim',"
        f"{inp_stride_vars},{index_stride_vars},{src_stride_vars},{shape_vars}])"
    )

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src_strided,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for index")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape")
        code.writeline("inp_size_dim,")
        code.writeline("stride_dim,")
        code.writeline("N,")
        # reduce options
        code.writeline("IS_ADD: tl.constexpr,")
        code.writeline("IS_MUL: tl.constexpr,")
        code.writeline("BLOCK: tl.constexpr,")
        code.writeline("LOOP: tl.constexpr,")
        code.writeline("INT32_OFFSET: tl.constexpr")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        # Promote the program id so the offset arithmetic below is 64-bit.
        code.writeline("if not INT32_OFFSET:")
        with code.indent():
            code.writeline("pid = pid.to(tl.int64)")
        code.writeline("offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)")

        # 1. Calculate inp_offsets and idx_offsets
        code.writeline("for loop_iter in tl.static_range(LOOP):")
        with code.indent():
            code.writeline("mask = offsets < N")
            code.writeline("cur_idx = offsets")
            code.writeline("if INT32_OFFSET:")
            with code.indent():
                code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
                code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
                code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("else:")
            with code.indent():
                code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
                code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
                code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
            # Unravel the linear position, last dimension first: the remainder
            # picks the coordinate and the quotient carries to the next dim.
            for i in range(rank)[::-1]:
                code.writeline("if INT32_OFFSET:")
                with code.indent():
                    code.writeline(f"shape_{i} = shape_{i}.to(tl.int32)")
                    code.writeline(f"inp_stride_{i} = inp_stride_{i}.to(tl.int32)")
                    code.writeline(f"index_stride_{i} = index_stride_{i}.to(tl.int32)")
                    code.writeline(f"src_stride_{i} = src_stride_{i}.to(tl.int32)")
                code.writeline(f"mod = cur_idx % shape_{i}")
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
                code.writeline(f"idx_offsets += mod * index_stride_{i}")
                code.writeline(f"src_offsets += mod * src_stride_{i}")
                if i != 0:
                    code.writeline(f"cur_idx = cur_idx // shape_{i}")

            # 2. Use offsets to scatter
            code.writeline(
                "cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)"
            )
            code.writeline(
                "cur_index = tl.load(index + idx_offsets, mask=mask, other=0)"
            )
            code.writeline("if INT32_OFFSET:")
            with code.indent():
                code.writeline("cur_index = cur_index.to(tl.int32)")
                code.writeline("stride_dim = stride_dim.to(tl.int32)")

            # Move along the scatter dimension to the position named by index.
            code.writeline("dim_offsets = cur_index * stride_dim")
            code.writeline("inp_offsets += dim_offsets")
            code.newline()
            code.writeline("if IS_ADD: ")
            with code.indent():
                code.writeline(
                    "tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
                )
            code.writeline("elif IS_MUL: ")
            with code.indent():
                code.writeline(
                    "tl.atomic_mul(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
                )

            code.writeline("else: ")
            with code.indent():
                code.writeline("tl.store(out + inp_offsets, cur_src, mask=mask)")

            # Advance to this program's next BLOCK-sized chunk.
            code.writeline("offsets += BLOCK")

    code.newline()
    code.newline()
    return code

180 

181 

def parameter_for_wrapper() -> str:
    """Return the parameter list (without parentheses) of the generated wrapper.

    The wrapper takes ``src_strided, index, inp, out, dim_size, dim_stride, N``
    positionally, followed by ``reduce`` and ``int32_offset``, which both
    default to ``None``.
    """
    positional: List[str] = [
        "src_strided",
        "index",
        "inp",
        "out",
        "dim_size",
        "dim_stride",
        "N",
    ]
    with_defaults: List[str] = [
        "reduce: tl.constexpr=None",
        "int32_offset: tl.constexpr=None",
    ]
    return ", ".join(positional + with_defaults)

197 

198 

def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python wrapper that launches the generated scatter kernel.

    The generated wrapper:

    - reads the strides of ``inp``, ``index`` and ``src_strided`` and the
      shape of ``index``;
    - derives ``IS_ADD`` / ``IS_MUL`` from ``reduce``;
    - launches ``kernel_name`` on a 1-D grid of
      ``cdiv(N, BLOCK * LOOP)`` programs;
    - returns ``out``.

    ``int32_offset`` defaults to True only when the caller passes ``None``;
    an explicit ``False`` selects the kernel's int64 offset path.

    Args:
        rank: Number of dimensions of the index tensor. One stride per tensor
            and one shape value are forwarded per dimension.
        wrapper_name: Name of the generated wrapper function.
        kernel_name: Name of the generated Triton kernel to launch.
        code: Buffer receiving the generated source.

    Returns:
        The same ``code`` buffer.
    """
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("inp_strides = list(inp.stride())")
        code.writeline("index_strides = index.stride()")
        code.writeline("src_strides = src_strided.stride()")
        code.writeline("index_shapes = list(index.shape)")
        code.writeline("inp_size_dim = dim_size")
        code.writeline("stride_dim = dim_stride")

        code.writeline('IS_ADD = reduce == "add"')
        code.writeline('IS_MUL = reduce == "multiply"')
        # Apply the default only when the caller did not decide: an
        # `int32_offset or True` default would also discard an explicit False
        # and force 32-bit offsets on tensors that need 64-bit ones.
        code.writeline("if int32_offset is None:")
        with code.indent():
            code.writeline("int32_offset = True")

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]), ')
        code.writeline(")")

        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)

        with code.indent():
            code.writeline("src_strided, index, inp, out, ")
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")

            code.writeline("inp_size_dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
            # reduce options
            code.writeline("IS_ADD,")
            code.writeline("IS_MUL,")
            code.writeline("INT32_OFFSET=int32_offset,")
            # Optional backend launch options, currently disabled:
            # code.writeline("buffer_size_limit=512,")
            # code.writeline("isCloseUnrollControl=True,")

        code.writeline(")")
        code.writeline("return out")

    return code

259 

260 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate the complete scatter module source: imports, kernel and wrapper.

    ``inputs`` follows the wrapper's positional order
    ``(src_strided, index, inp, out, dim_size, dim_stride, N, reduce)``. The
    kernel and wrapper are specialized on the rank of ``inputs[1]`` (the
    index tensor).

    Args:
        inputs: Arguments the generated wrapper will be called with.
        wrapper_name: Name for the generated launch wrapper.
        kernel_name: Name for the generated Triton kernel.
        code: Buffer receiving the generated source.

    Returns:
        The buffer holding the full generated module.
    """
    index_rank = len(inputs[1].shape)

    with_imports = generate_imports(code)
    with_kernel = generate_scatter_kernel(index_rank, kernel_name, with_imports)
    return generate_destination_passing_wrapper(
        index_rank, wrapper_name, kernel_name, with_kernel
    )

275 

276 

class ScatterFunction:
    """Rank-specialized dispatcher for the generated scatter kernels.

    On the first call for a given key (the highest rank among the tensor
    arguments):

    1. generates a Triton kernel plus a Python launch wrapper;
    2. writes them to ``code_cache_dir()/scatter_rank_<key>.py``;
    3. imports that file and memoizes its ``_scatter_wrapper``.

    Later calls with the same key reuse the cached wrapper.
    """

    def __init__(self):
        # Process id at construction time (not read by this class).
        self.pid = os.getpid()
        # Maps str(arg_key(...)) to the loaded ``_scatter_wrapper`` callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the wrapper for the rank of ``args``, building it if needed.

        All positional and keyword arguments are forwarded unchanged to the
        wrapper, and its return value is returned.
        """
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, write, import and return the wrapper for ``key``."""
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly.
        import importlib.util

        code = generate_code(
            args,
            "_scatter_wrapper",
            "_scatter_jit_function",
            IndentedBuffer(),
        )

        file_path = code_cache_dir() / f"scatter_rank_{key}.py"
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_scatter_wrapper")

    def arg_key(self, *args):
        """Return the highest ``ndim`` among the tensor arguments.

        Raises:
            ValueError: If ``args`` contains no tensors.
        """
        return max(item.ndim for item in args if torch.is_tensor(item))

316 

317 

# Shared dispatcher used by scatter/scatter_; it caches the generated wrappers
# across calls.
_scatter_func = ScatterFunction()

319 

320 

321def scatter(inp, dim, index, src, reduce=None): 

322 logger.debug("GEMS SCATTER") 

323 out = inp.clone() 

324 

325 if reduce is not None: 

326 assert inp.dtype not in ( 

327 torch.bfloat16, 

328 ), "Unsupported operation: reduce scatter bfloat tensors." 

329 

330 if has_internal_overlapping(out) == MemOverlap.Yes: 

331 out = out.contiguous() 

332 

333 src_strided = src.as_strided(index.shape, src.stride()) 

334 inp_restrided = restride_dim(inp, dim, index.shape) 

335 dim_size = inp.size(dim) 

336 dim_stride = inp.stride(dim) 

337 N = index.numel() 

338 

339 int32_size_dim = lambda x: x.stride(dim) * x.size(dim) < 2**32 

340 use_int32_offset = all(map(int32_size_dim, (inp, index, src))) 

341 _scatter_func( 

342 src_strided, 

343 index, 

344 inp_restrided, 

345 out, 

346 dim_size, 

347 dim_stride, 

348 N, 

349 reduce, 

350 int32_offset=use_int32_offset, 

351 ) 

352 

353 return out 

354 

355 

356def scatter_(inp, dim, index, src, reduce=None): 

357 logger.debug("GEMS SCATTER_") 

358 out = inp 

359 

360 if reduce is not None: 

361 assert inp.dtype not in ( 

362 torch.bfloat16, 

363 ), "Unsupported operation: reduce scatter bfloat tensors." 

364 

365 assert ( 

366 has_internal_overlapping(out) != MemOverlap.Yes 

367 ), "Unsupported operation: trying to inplace write to an internally overlapping tensor." 

368 

369 src_restrided = src.as_strided(index.shape, src.stride()) 

370 inp_restrided = restride_dim(inp, dim, index.shape) 

371 dim_size = inp.size(dim) 

372 dim_stride = inp.stride(dim) 

373 N = index.numel() 

374 

375 int32_size_dim = lambda x: x.stride(dim) * x.size(dim) < 2**32 

376 use_int32_offset = all(map(int32_size_dim, (inp, index, src))) 

377 _scatter_func( 

378 src_restrided, 

379 index, 

380 inp_restrided, 

381 out, 

382 dim_size, 

383 dim_stride, 

384 N, 

385 reduce, 

386 int32_offset=use_int32_offset, 

387 ) 

388 

389 return inp