Coverage for src/flag_gems/runtime/backend/_cambricon/ops/index_add.py: 0%

193 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7import triton 

8import triton.language as tl 

9 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils.code_cache import code_cache_dir 

12from flag_gems.utils.code_utils import IndentedBuffer 

13 

# Child logger under the package-wide "flag_gems" logger; lstrip(".") guards
# against a leading dot in __name__ (normally a no-op for absolute names).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

def cfggen():
    """Enumerate the autotuning configs for ``index_add_kernel``.

    Returns one ``triton.Config`` per (BLOCK_M, BLOCK_N) pair from the
    cartesian product of the candidate tile sizes, all with num_warps=1.
    """
    tile_rows = (1, 2, 4, 8)
    tile_cols = (128, 1024, 2048, 4096)
    return [
        triton.Config({"BLOCK_M": rows, "BLOCK_N": cols}, num_warps=1)
        for rows in tile_rows
        for cols in tile_cols
    ]

26 

27 

@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def index_add_kernel(
    inp,  # input data pointer; the indexed dim has been compressed to the last axis
    out,  # output data pointer (a clone of inp, written in place)
    index,  # integer tensor of destination columns along the last axis
    src,  # source data pointer, M rows by N columns
    M,  # row count: src.numel() // index.numel()
    N,  # column count: index.numel()
    alpha,  # scalar multiplier applied to src
    inp_len,  # size of inp along the indexed (last) axis
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Computes out[row, index[col]] = inp[row, index[col]] + alpha * src[row, col]
    # over a 2D grid: axis 0 tiles rows (BLOCK_M), axis 1 tiles columns (BLOCK_N).
    #
    # NOTE(review): duplicate values in `index` race on the same output element
    # (plain load/store, no atomics) — presumably callers pass unique indices;
    # confirm against the expected index_add semantics.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    # rows_offsets is (BLOCK_M, 1) and cols_offsets is (BLOCK_N,); they
    # broadcast to a (BLOCK_M, BLOCK_N) tile in the offset math below.
    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)

    rows_mask = rows_offsets < M
    index_mask = cols_offsets < N
    # NOTE(review): `and` between tensor masks — the Triton frontend treats
    # this as an elementwise combination, but `&` is the conventional
    # spelling; confirm on the targeted triton version.
    block_mask = rows_mask and index_mask

    # Gather the destination column for each in-bounds index position.
    cur_indices = tl.load(index + cols_offsets, mask=index_mask, other=0)
    inp_off = rows_offsets * inp_len + cur_indices[None, :]
    cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0)
    src_off = rows_offsets * N + cols_offsets[None, :]
    cur_src = tl.load(src + src_off, mask=block_mask, other=0.0)
    cur_inp += alpha * cur_src

    tl.store(out + inp_off, cur_inp, mask=block_mask)

60 

61 

def index_add(inp, dim, index, src, alpha=1):
    """Out-of-place index_add: returns a copy of ``inp`` with
    ``alpha * src`` accumulated at positions ``index`` along ``dim``.

    Args:
        inp: input tensor.
        dim: dimension along which to index (may be negative).
        index: integer tensor of positions along ``dim``;
            requires ``0 <= index < inp.size(dim)``.
        src: values to add; same ndim as ``inp`` and
            ``src.size(dim) == index.numel()``.
        alpha: scalar multiplier applied to ``src``. Defaults to 1.

    Returns:
        A new contiguous tensor with the accumulated values.
    """
    logger.debug("GEMS_CAMBRICON INDEX ADD")
    # Fix: validate on inp's own device instead of a hard-coded "cuda" —
    # this backend targets Cambricon devices, and the sibling index_add_
    # already uses inp.device.
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    # Fix: the original asserted a bare generator expression, which is always
    # truthy; all(...) actually evaluates the per-dimension shape check.
    assert all(
        (inp.size(i) == src.size(i)) or i == dim for i in range(0, inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()
    M = src.numel() // N
    fine_dim = inp.ndim - 1
    # The kernel indexes along the last axis; move `dim` there if needed.
    if dim != fine_dim:
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)
    out = inp.clone()

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)
    if dim != fine_dim:
        # Undo the dim compression: move the last axis back to position `dim`.
        order = [i for i in range(out.ndim - 1)]
        order.insert(dim, fine_dim)
        return out.permute(order).contiguous()
    else:
        return out

103 

104 

def generate_imports(code: "IndentedBuffer") -> "IndentedBuffer":
    """Append the import header of a generated index_add module to ``code``.

    Writes the triton imports plus the libentry decorator import, followed
    by two blank lines, and returns the same buffer for chaining.
    """
    header = (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    )
    for line in header:
        code.writeline(line)

    code.newline()
    code.newline()

    return code

114 

115 

def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the source of a rank-specialized Triton index_add kernel.

    The generated kernel decomposes each flat src offset into per-dimension
    coordinates, maps it to the corresponding destination offset in the
    input, and accumulates ``alpha * src`` there with ``tl.atomic_add``.

    Args:
        rank: ndim of the source tensor; controls how many stride/shape
            parameters the generated signature takes.
        kernel_name: name given to the generated @triton.jit function.
        code: buffer the generated lines are appended to.

    Returns:
        The same buffer, for chaining.
    """
    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature — emitted only for rank > 0; a rank-0 src would produce an
    # empty parameter list (and a bare ", # stride for src" line below).
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")
            code.writeline("inp_shape_dim,")
            code.writeline("src_shape_dim,")
            code.writeline("delta,")
            code.writeline("alpha,")

            # One stride and one shape scalar per src dimension.
            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

            code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # Unravel the flat offset into per-dimension coordinates
        # (last dimension varies fastest), then re-linearize with the
        # actual src strides to get the element's memory offset.
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        comp = [f"src_offset{i} * src_stride_{i}" for i in range(rank)]
        code.writeline(f"src_offset = {' + '.join(comp)}")

        # Elements per full slice of the indexed dimension in src layout.
        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")

        # index add: split the src offset into (slice index, position along
        # dim), translate dim position through `index`, and rebuild the
        # destination offset (delta accounts for inp being larger than src
        # along dim).
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = input_idx < inp_numel")
        code.writeline(
            "add_on = tl.load(src + src_offset, mask=mask, other=0) * alpha"
        )
        # Atomic accumulation makes duplicate indices in `index` safe.
        code.writeline(
            "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
        )
        # TODO: tl.atomic_add doesn't support bfloat16! The following method may be unsafe.
        # code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
        # code.writeline("tl.store(out + input_idx, cur_out + add_on, mask=input_mask)")

    code.newline()
    code.newline()
    return code

194 

195 

def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper.

    Order matches both the generated wrapper signature and the positional
    call made by ``index_add_`` through ``_index_add_func``.
    """
    names = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
        "alpha",
    )
    return ", ".join(names)

212 

213 

def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python wrapper that launches the generated kernel.

    The wrapper takes the parameter list from ``parameter_for_wrapper``,
    derives src strides/shapes, computes a 1D grid over N elements, and
    launches ``kernel_name``.

    Args:
        rank: ndim of the source tensor (controls stride/shape args).
        wrapper_name: name of the generated wrapper function.
        kernel_name: name of the generated kernel it launches.
        code: buffer the generated lines are appended to.

    Returns:
        The same buffer, for chaining.
    """
    parameters: str = parameter_for_wrapper()
    # NOTE(review): the space before "(" in the generated def is unusual
    # but harmless to Python's parser.
    wrapper_signature: str = f"def {wrapper_name} ({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("src_strides = list(src.stride())")
        code.writeline("src_shapes = list(src.shape)")

        # kernel launch
        code.writeline("BLOCK_SIZE = 640")  # BLOCK_SIZE setting
        code.writeline("grid = (triton.cdiv(N, BLOCK_SIZE),)")
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            # Scalar arguments first, matching the generated signature order.
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            if rank > 0:
                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")
        code.writeline(")")
        code.writeline("return out")

    return code

248 

249 

def generate_code(
    inputs: "Tuple[Any]",
    wrapper_name: str,
    kernel_name: str,
    code: "IndentedBuffer",
) -> "IndentedBuffer":
    """Assemble a complete generated module: imports, kernel, wrapper.

    ``inputs`` mirrors the wrapper parameter list
    (out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim,
    delta, N, inp_numel, alpha); position 2 is ``src``, whose rank drives
    the specialization.
    """
    rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_index_add_kernel(rank, kernel_name, code)
    return generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)

264 

265 

class IndexAddFunction:
    """Code-generating dispatcher for the in-place index_add_ path.

    On first call for a given tensor rank, renders a rank-specialized
    kernel + wrapper module into the code cache directory, imports it,
    and caches the wrapper callable; later calls dispatch directly.
    """

    def __init__(self):
        # pid keeps generated file/module names unique across processes
        # that share the same code cache directory.
        self.pid = os.getpid()
        # Maps stringified rank key -> generated wrapper callable.
        # NOTE(review): annotated Mapping but mutated below; Dict would be
        # the precise type.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # Dispatch to (or generate and cache) the wrapper for these args.
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # Render the specialized module source into a buffer.
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_index_add_wrapper",
                "_index_add_jit_function",
                code,
            )

            file_name = f"index_add_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load — f is closed here, but f.name still holds the path.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_add_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        # Cache key: the maximum rank among tensor arguments.
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank

306 

307 

# Module-level singleton: caches one generated overload per tensor rank.
_index_add_func = IndexAddFunction()

309 

310 

def index_add_(inp, dim, index, src, alpha=1):
    """In-place index_add: accumulate ``alpha * src`` into ``inp`` at
    positions ``index`` along ``dim``; returns the mutated ``inp``.

    Dispatches to a rank-specialized generated Triton kernel via
    ``_index_add_func``.

    Args:
        inp: destination tensor, modified in place.
        dim: dimension along which to index (may be negative).
        index: integer tensor of positions along ``dim``;
            requires ``0 <= index < inp.size(dim)``.
        src: values to accumulate; same ndim as ``inp`` and
            ``src.size(dim) == index.numel()``.
        alpha: scalar multiplier applied to ``src``. Defaults to 1.
    """
    logger.debug("GEMS_CAMBRICON INDEX ADD_")
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    # Fix: the original asserted a bare generator expression, which is always
    # truthy; all(...) actually evaluates the per-dimension shape check.
    assert all(
        (inp.size(i) == src.size(i)) or i == dim for i in range(0, inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    dim %= inp.ndim
    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = inp.size(dim)
    # How much larger inp is than src along dim; the generated kernel uses
    # this to translate src offsets to inp offsets.
    delta = inp.size(dim) - src_shape_dim
    N = src.numel()

    _index_add_func(
        inp,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        inp.numel(),
        alpha,
    )
    return inp