Coverage for src/flag_gems/runtime/backend/_cambricon/ops/scatter.py: 0%

184 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer 

10 

# Module logger nested under the shared "flag_gems" logger hierarchy.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Emit the import prologue shared by every generated scatter module.

    Writes the triton/torch imports followed by the flag_gems runtime
    imports, separated by blank lines, and returns ``code`` for chaining.
    """
    for line in (
        "import torch",
        "import triton",
        "import triton.language as tl",
    ):
        code.writeline(line)
    code.newline()

    for line in (
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils import triton_lang_extension as tle",
    ):
        code.writeline(line)
    code.newline()
    code.newline()
    return code

25 

26 

def generate_scatter_kernel(
    rank: int,
    dim: int,
    large_tensor: bool,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Triton scatter kernel source into ``code``.

    Args:
        rank: Number of dimensions of the index tensor; one stride/shape
            argument is generated per dimension.
        dim: Dimension scattered along.  The input-offset contribution of
            this dimension comes from the loaded index values (via
            ``stride_dim``), so it is skipped when unraveling ``inp``
            offsets — but NOT ``src`` offsets, see the bugfix note below.
        large_tensor: When True, offsets accumulate in int64 so buffers
            larger than 2**31 bytes address correctly; otherwise int32.
        kernel_name: Name of the generated ``@triton.jit`` function.
        code: Destination buffer; returned for chaining.
    """
    code.newline()
    code.newline()
    code.newline()

    # the decorators: cache entry + autotune over BLOCK_SIZE keyed on N
    code.writeline("@libentry()")
    code.writeline(
        '@libtuner(configs=runtime.get_tuned_config("scatter"), key=["N"], strategy=["log"],'
    )
    code.writeline(' restore_value=["out"], )')
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"index_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for index")

        code.writeline("dim,")
        code.writeline("stride_dim,")
        code.writeline("N,")
        # reduce options
        code.writeline("IS_ADD: tl.constexpr,")
        code.writeline("IS_MUL: tl.constexpr,")
        code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # 1. Calculate inp_offsets and src_offsets accumulators; int64 only
        # when a buffer may exceed 2**31 bytes (int32 is faster otherwise).
        acc_dtype = "tl.int64" if large_tensor else "tl.int32"
        code.writeline(f"inp_offsets = tl.zeros((BLOCK_SIZE,), dtype={acc_dtype})")
        code.writeline(f"src_offsets = tl.zeros((BLOCK_SIZE,), dtype={acc_dtype})")

        code.writeline("cur_idx = offsets")

        # 2. Unravel the flat offset over the index shape, innermost first.
        for i in range(rank - 1, -1, -1):
            code.writeline(f"mod = cur_idx % index_shape_{i}")
            if dim != i:
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
            # BUGFIX: src must advance along EVERY dimension, including `dim`.
            # Previously this line sat inside the `if dim != i` guard, so the
            # dim-component of the src offset was dropped and wrong src
            # elements were scattered.  Only inp's dim-component is supplied
            # separately (the wrapper zeroes inp_strides[dim] and the kernel
            # adds cur_index * stride_dim below); src has no such substitute.
            code.writeline(f"src_offsets += mod * src_stride_{i}")
            # the last "//" should be optimized out
            code.writeline(f"cur_idx = cur_idx // index_shape_{i}")

        # 3. Use offsets to scatter
        code.writeline("cur_src = tl.load(src + src_offsets, mask=mask, other=0)")
        if large_tensor:
            code.writeline("cur_index = tl.load(index + offsets, mask=mask, other=0)")
        else:
            code.writeline(
                "cur_index = tl.load(index + offsets, mask=mask, other=0).to(tl.int32)"
            )
        code.writeline("inp_offsets += cur_index * stride_dim")

        code.newline()
        code.writeline("if IS_ADD: ")
        with code.indent():
            code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
            code.writeline("res = cur_inp + cur_src")
            code.writeline("tl.store(out + inp_offsets, res, mask=mask)")

        code.writeline("elif IS_MUL: ")
        with code.indent():
            code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
            code.writeline("res = cur_inp * cur_src")
            code.writeline("tl.store(out + inp_offsets, res, mask=mask)")

        code.writeline("else: ")
        with code.indent():
            code.writeline("tl.store(out + inp_offsets, cur_src, mask=mask)")

    code.newline()
    code.newline()
    return code

134 

135 

def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper.

    Order is fixed: src, index, inp, out, dim, reduce, N — it must match
    both the generated wrapper signature and the positional call made by
    ScatterFunction.__call__.
    """
    return ", ".join(("src", "index", "inp", "out", "dim", "reduce", "N"))

149 

150 

def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python wrapper that prepares strides/flags and launches the kernel.

    The parameter list below is the fixed contract also produced by
    parameter_for_wrapper(): src, index, inp, out, dim, reduce, N.
    """
    code.writeline(f"def {wrapper_name}(src, index, inp, out, dim, reduce, N):")

    with code.indent():
        # Stride/shape bookkeeping.  inp's stride at `dim` is zeroed because
        # the kernel re-adds that dimension's offset from the index values.
        for stmt in (
            "inp_strides = list(inp.stride())",
            "src_strides = src.stride()",
            "index_shapes = list(index.shape)",
            "stride_dim = inp_strides[dim]",
            "inp_strides[dim] = 0",
            'IS_ADD = reduce == "add"',
            'IS_MUL = reduce == "multiply"',
        ):
            code.writeline(stmt)

        # 1D launch grid over the flattened index space.
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK_SIZE"]),')
        code.writeline(")")

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline("src, index, inp, out, ")
            if rank > 0:
                for template in ("inp_strides[{}]", "src_strides[{}]", "index_shapes[{}]"):
                    joined = ", ".join(template.format(i) for i in range(rank))
                    code.writeline(f"{joined},")

            for tail in ("dim,", "stride_dim,", "N,", "IS_ADD,", "IS_MUL,"):
                code.writeline(tail)
        code.writeline(")")
        code.writeline("return out")

    return code

202 

203 

def generate_code(
    rank: int,
    dim: int,
    large_input: bool,
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble a complete scatter module: imports, kernel, then wrapper.

    ``inputs`` is (src, index, inp, out, dim, reduce, N).  The incoming
    ``rank`` argument is immediately overwritten — the effective rank is
    re-derived from the index tensor's shape.
    """
    rank = len(inputs[1].shape)  # index tensor determines the iteration rank

    generate_imports(code)
    generate_scatter_kernel(rank, dim, large_input, kernel_name, code)
    return generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)

221 

222 

class ScatterFunction:
    """Caches generated scatter wrappers keyed by rank/dim/size class.

    Each distinct key triggers one code-generation pass: the module source
    is written into the code cache directory, imported via importlib, and
    its ``_scatter_wrapper`` entry point is memoized for reuse.
    """

    def __init__(self):
        # The pid is baked into generated file/module names so concurrent
        # processes sharing the cache directory do not clobber each other.
        self.pid = os.getpid()
        # NOTE(review): annotated as Mapping but mutated below — it is
        # effectively a Dict[str, Callable] cache of generated wrappers.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # args: (src, index, inp, out, dim, reduce, N) — forwarded verbatim
        # to the generated wrapper.  kwargs carry the specialization key
        # parts (rank, dim, large_tensor).
        rank = kwargs["rank"]
        dim = kwargs["dim"]
        large_tensor = kwargs["large_tensor"]

        # Cache key: <max tensor rank>_<rank>_<dim>_<large_tensor>.
        key = f"{self.arg_key(*args)}_{rank}_{dim}_{large_tensor}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Generate the specialized module source into an in-memory buffer.
            code = IndentedBuffer()
            code = generate_code(
                rank,
                dim,
                large_tensor,
                args,
                "_scatter_wrapper",
                "_scatter_jit_function",
                code,
            )

            file_name = f"scatter_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load the just-written module; f.name (the path string) remains
            # valid after the with-block even though the file is closed
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_scatter_wrapper")
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args):
        # Specialize on the maximum rank among the tensor arguments.
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank

270 

271 

# Module-level singleton: one shared wrapper cache for all scatter calls.
_scatter_func = ScatterFunction()

273 

274 

def scatter(inp, dim, index, src, reduce=None):
    """Out-of-place scatter of ``src`` into a copy of ``inp`` along ``dim``.

    ``reduce`` may be None, "add", or "multiply"; returns the new tensor.
    """
    logger.debug("GEMS_CAMBRICON SCATTER")
    inp, index, src = inp.contiguous(), index.contiguous(), src.contiguous()
    out = inp.clone()

    N = index.numel()

    # int32 offsets overflow past 2**31 addressable bytes; any oversized
    # buffer forces the int64-offset kernel variant.
    large_tensor = any(
        t.numel() * t.element_size() > 2**31 for t in (src, out)
    )

    # <rank>_<dim>_<large_tensor> is part of the key of overloads
    _scatter_func(
        src,
        index,
        inp,
        out,
        dim,
        reduce,
        N,
        rank=index.ndim,
        large_tensor=large_tensor,
        dim=dim,
    )
    return out

302 

303 

def scatter_(inp, dim, index, src, reduce=None):
    """In-place scatter of ``src`` into ``inp`` along ``dim``.

    ``reduce`` may be None, "add", or "multiply".  Returns ``inp`` after
    mutating it, matching torch's in-place convention.
    """
    logger.debug("GEMS_CAMBRICON SCATTER_")
    index = index.contiguous()
    src = src.contiguous()
    # The kernel requires a contiguous destination.  .contiguous() returns a
    # COPY when inp is not already contiguous, so results must be written
    # back afterwards to honor the in-place contract.  (Previously the copy
    # was mutated and returned while the caller's tensor stayed untouched.)
    out = inp.contiguous()

    N = index.numel()

    large_tensor = (src.numel() * src.element_size() > 2**31) or (
        out.numel() * out.element_size() > 2**31
    )

    # <rank>_<dim>_<large_tensor> is part of the key of overloads; `dim` is
    # passed positionally for the generated wrapper and again as a kwarg for
    # the cache key.  Destination doubles as the kernel's input tensor.
    _scatter_func(
        src,
        index,
        out,
        out,
        dim,
        reduce,
        N,
        rank=len(index.shape),
        large_tensor=large_tensor,
        dim=dim,
    )

    if out is not inp:
        # contiguous() made a copy — propagate results into the caller's tensor.
        inp.copy_(out)
    return inp