Coverage for src/flag_gems/ops/index_add.py: 100%

157 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer 

10 

11logger = logging.getLogger(__name__) 

12 

13 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header shared by every generated kernel module.

    Args:
        code: Buffer the import lines are appended to.

    Returns:
        The same buffer, for chaining.
    """
    for line in (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    ):
        code.writeline(line)

    code.newline()
    code.newline()

    return code

23 

24 

def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the source of a Triton JIT kernel that performs index_add.

    The generated kernel scatters ``src * alpha`` into ``out`` along one
    dimension, using ``index`` to choose destination positions along that
    dimension.

    Args:
        rank: Number of dimensions of the ``src`` tensor; fixes how many
            stride/shape arguments the generated kernel signature takes.
        kernel_name: Name of the generated ``@triton.jit`` function.
        code: Buffer the generated source is appended to.

    Returns:
        The same buffer, for chaining.
    """
    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")
            code.writeline("inp_shape_dim,")
            code.writeline("src_shape_dim,")
            code.writeline("delta,")
            code.writeline("alpha,")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

            code.writeline("BLOCK_SIZE: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # Decompose the flat element id into per-dimension src coordinates,
        # innermost dimension first (the generated `offsets` variable is
        # consumed in the process; `mask` was computed beforehand).
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        # Rebuild the strided (flat) offset into src from the coordinates.
        comp = [f"src_offset{i} * src_stride_{i}" for i in range(rank)]
        code.writeline(f"src_offset = {' + '.join(comp)}")

        # pre_cal = stride along dim * src extent along dim — presumably the
        # flat span of one full pass over the indexed dimension; assumes src
        # is contiguous in inp's layout (TODO confirm for non-contiguous src).
        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")

        # index add
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        # Map the src offset to the destination offset: shift by the size
        # difference (delta) accumulated over earlier passes, plus the
        # index-directed displacement along the target dimension.
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = input_idx < inp_numel")
        code.writeline(
            "add_on = tl.load(src + src_offset, mask=mask, other=0) * alpha"
        )
        # Relaxed atomic add: duplicate values in `index` make multiple
        # lanes target the same out element.
        code.writeline(
            "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
        )
        # TODO: tl.atomic_add doesn't support bfloat16! The following method may be unsafe.
        # code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
        # code.writeline("tl.store(out + input_idx, cur_out + add_on, mask=input_mask)")

    code.newline()
    code.newline()
    return code

103 

104 

def parameter_for_wrapper() -> str:
    """Return the comma-joined parameter list of the generated wrapper.

    The order must match both the generated wrapper's signature and the
    positional call sites in ``index_add`` / ``index_add_``.  (The original
    comment here omitted ``inp_shape_dim``; the tuple below is the actual
    order.)
    """
    parameters = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
        "alpha",
    )
    return ", ".join(parameters)

121 

122 

def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the host-side wrapper that launches the generated kernel.

    Args:
        rank: Number of dimensions of the ``src`` tensor.
        wrapper_name: Name of the generated Python wrapper function.
        kernel_name: Name of the generated Triton kernel it launches.
        code: Buffer the generated source is appended to.

    Returns:
        The same buffer, for chaining.
    """
    parameters: str = parameter_for_wrapper()
    # Fix: the original emitted "def name (params):" with a stray space
    # between the function name and its parameter list.
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("src_strides = list(src.stride())")
        code.writeline("src_shapes = list(src.shape)")

        # kernel launch
        code.writeline("BLOCK_SIZE = 128")  # BLOCK_SIZE setting
        code.writeline("grid = (triton.cdiv(N, BLOCK_SIZE),)")
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            if rank > 0:
                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")
        code.writeline(")")
        code.writeline("return out")

    return code

157 

158 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the full generated module: imports, kernel, then wrapper.

    Args:
        inputs: Runtime call arguments in the order
            [out, index, src, dim, inp_stride_dim, inp_shape_dim,
            src_shape_dim, delta, N, inp.numel(), alpha]; only the rank of
            ``src`` (``inputs[2]``) drives code generation.
        wrapper_name: Name of the generated wrapper function.
        kernel_name: Name of the generated Triton kernel.
        code: Buffer the generated source is appended to.

    Returns:
        The same buffer, for chaining.
    """
    src_rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_index_add_kernel(src_rank, kernel_name, code)
    return generate_destination_passing_wrapper(
        src_rank, wrapper_name, kernel_name, code
    )

173 

174 

class IndexAddFunction:
    """JIT dispatcher that caches rank-specialized index_add overloads.

    On first call for a given rank, the generated source is written to the
    code cache directory, imported as a module, and the resulting wrapper
    is memoized; later calls with the same rank reuse it directly.
    """

    def __init__(self):
        self.pid = os.getpid()
        # Mutated below, so declare a concrete dict rather than the
        # read-only typing.Mapping the original annotation claimed.
        self.overloads: dict = {}

    def __call__(self, *args, **kwargs):
        # The cache key is the max rank among tensor arguments.
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_index_add_wrapper",
                "_index_add_jit_function",
                code,
            )

            # Write the generated source where it can be inspected; the pid
            # suffix avoids clashes between concurrent processes.
            file_name = f"index_add_rank_{key}_pid_{self.pid}.py"
            file_path = code_cache_dir() / file_name
            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load the generated module
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                file_path,
            )
            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            # Direct attribute access instead of getattr with a literal.
            overload = m._index_add_wrapper
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        """Return the max ndim among tensor args — the specialization key."""
        tensors = [item for item in args if torch.is_tensor(item)]
        return max(item.ndim for item in tensors)

215 

216 

217_index_add_func = IndexAddFunction() 

218 

219 

def index_add(inp, dim, index, src, alpha=1):
    """Out-of-place index_add: return a clone of ``inp`` with ``alpha * src``
    accumulated along dimension ``dim`` at the positions given by ``index``.

    Args:
        inp: Destination tensor (not modified; a clone is returned).
        dim: Dimension along which to index; may be negative.
        index: 1-D tensor of destination indices, one per src slice.
        src: Tensor of values to accumulate; same ndim as ``inp``.
        alpha: Scalar multiplier applied to ``src``. Defaults to 1.

    Returns:
        A new tensor equal to ``inp`` with the accumulation applied.
    """
    logger.debug("GEMS INDEX ADD")
    # NOTE: these validation asserts vanish under `python -O`.
    # Equivalent to the original ones-tensor comparison, without the
    # intermediate allocation.
    assert bool(
        torch.logical_and(0 <= index, index < inp.size(dim)).all()
    ), "0 <= index < self.size(dim)"
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    # Bug fix: the original asserted on a bare generator expression, which
    # is always truthy, so this shape check never actually ran.
    assert all(
        inp.size(i) == src.size(i) or i == dim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    out = inp.clone()

    dim %= inp.ndim  # normalize negative dims
    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = inp.size(dim)
    delta = inp.size(dim) - src_shape_dim
    N = src.numel()

    _index_add_func(
        out,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        inp.numel(),
        alpha,
    )
    return out

259 

260 

def index_add_(inp, dim, index, src, alpha=1):
    """In-place index_add: accumulate ``alpha * src`` into ``inp`` along
    dimension ``dim`` at the positions given by ``index``.

    Args:
        inp: Destination tensor, modified in place and returned.
        dim: Dimension along which to index; may be negative.
        index: 1-D tensor of destination indices, one per src slice.
        src: Tensor of values to accumulate; same ndim as ``inp``.
        alpha: Scalar multiplier applied to ``src``. Defaults to 1.

    Returns:
        ``inp``, after the accumulation.
    """
    logger.debug("GEMS INDEX ADD_")
    # NOTE: these validation asserts vanish under `python -O`.
    # Equivalent to the original ones-tensor comparison, without the
    # intermediate allocation.
    assert bool(
        torch.logical_and(0 <= index, index < inp.size(dim)).all()
    ), "0 <= index < self.size(dim)"
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    # Bug fix: the original asserted on a bare generator expression, which
    # is always truthy, so this shape check never actually ran.
    assert all(
        inp.size(i) == src.size(i) or i == dim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    dim %= inp.ndim  # normalize negative dims
    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = inp.size(dim)
    delta = inp.size(dim) - src_shape_dim
    N = src.numel()

    _index_add_func(
        inp,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        inp.numel(),
        alpha,
    )
    return inp