Coverage for src/flag_gems/runtime/backend/_ascend/ops/index_add.py: 0%

145 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer 

10 

11logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

12 

13 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import header of the generated index_add module to ``code``.

    The generated kernel uses ``triton``/``tl`` directly, ``libentry`` as a
    decorator and ``runtime.get_tuned_config`` for autotuning configs.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    header_lines = (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
    )
    for line in header_lines:
        code.writeline(line)

    # Two blank lines before the first top-level definition.
    for _ in range(2):
        code.newline()

    return code

23 

24 

def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the Triton JIT kernel that scatter-adds ``alpha * src`` into ``out``.

    Each program handles ``BLOCK_SIZE`` consecutive linear positions of ``src``.
    Each position is decomposed into per-dimension coordinates with
    ``src_shape_{i}`` and turned into a memory offset with ``src_stride_{i}``.
    Its coordinate along ``dim`` is then replaced by ``index[coordinate]`` to
    locate the destination element in ``out``. That element is updated with
    ``tl.atomic_add``, so repeated index values accumulate.

    NOTE(review): ``pre_idx``/``dim_idx`` treat ``src_offset`` as a row-major
    linear offset and ``inp_stride_dim`` as the element count below ``dim``.
    This is only correct when ``src`` and ``out`` are contiguous.

    Args:
        rank: number of dimensions of ``src``. One stride and one shape argument
            is generated per dimension. Must be positive (for ``rank == 0`` the
            generated source is not valid).
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the kernel source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    # the decorators
    code.writeline("@libentry()")
    code.writeline(
        '@triton.autotune(configs=runtime.get_tuned_config("index_add"), key=["BLOCK_SIZE"])'
    )
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")
            code.writeline("inp_shape_dim,")
            code.writeline("src_shape_dim,")
            code.writeline("delta,")
            code.writeline("alpha,")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

            code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # Mixed-radix decomposition of the linear position, innermost dim first.
        # The outermost coordinate is also reduced modulo its extent, so lanes
        # with offsets >= N wrap around onto real src elements. They have to be
        # excluded from the final atomic update via `mask`.
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        comp = [f"src_offset{i} * src_stride_{i}" for i in range(rank)]
        code.writeline(f"src_offset = {' + '.join(comp)}")

        # Number of elements in one outer slice (all of `dim` and below).
        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")

        # index add: split src_offset into (outer, dim, inner), swap the dim
        # coordinate for index[dim_idx], and rebase onto out's dim extent
        # (delta = inp_shape_dim - src_shape_dim).
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        # Combine with `mask`. Without it, out-of-range lanes (which wrap onto
        # valid elements) issue an atomic add of 0 to a real output element.
        # That wastes atomics and turns a stored -0.0 into +0.0.
        code.writeline("input_mask = mask & (input_idx < inp_numel)")
        code.writeline(
            "add_on = tl.load(src + src_offset, mask=mask, other=0) * alpha"
        )
        code.writeline(
            "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
        )
        # TODO: tl.atomic_add doesn't support bfloat16! A plain load/add/store
        # fallback would not be safe when index contains duplicates.

    code.newline()
    code.newline()
    return code

106 

107 

def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper.

    The order matches the positional arguments that ``index_add`` passes:
    out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta,
    N, inp_numel, alpha.
    """
    wrapper_params = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
        "alpha",
    )
    return ", ".join(wrapper_params)

124 

125 

def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the host-side wrapper that launches ``kernel_name`` and returns ``out``.

    The wrapper takes the parameters from ``parameter_for_wrapper()`` and builds a
    1-D grid of ``cdiv(N, BLOCK_SIZE)`` programs. It passes the scalar arguments,
    then ``rank`` src strides and ``rank`` src shapes. ``BLOCK_SIZE`` is not
    passed; presumably it is supplied by the kernel's autotune configs.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    params = parameter_for_wrapper()
    code.writeline(f"def {wrapper_name} ({params}):")

    with code.indent():
        for stmt in ("src_strides = list(src.stride())", "src_shapes = list(src.shape)"):
            code.writeline(stmt)

        # kernel launch: one program per BLOCK_SIZE src elements
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE']), ")
        code.writeline(")")

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            if rank > 0:
                # strides first, then shapes, matching the kernel signature
                for seq_name in ("src_strides", "src_shapes"):
                    items = ", ".join(f"{seq_name}[{d}]" for d in range(rank))
                    code.writeline(items + ",")
        code.writeline(")")
        code.writeline("return out")

    return code

162 

163 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble a full generated module: imports, Triton kernel, launch wrapper.

    ``inputs`` is the positional argument tuple that will be passed to the wrapper:
    (out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta,
    N, inp_numel, alpha). Only ``inputs[2]`` (src) is inspected, for its rank.
    """
    src_rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_index_add_kernel(src_rank, kernel_name, code)
    code = generate_destination_passing_wrapper(
        src_rank, wrapper_name, kernel_name, code
    )
    return code

178 

179 

class IndexAddFunction:
    """Generate, cache and dispatch index_add wrappers, one per rank key.

    On the first call for a key (the largest ``ndim`` among the tensor
    arguments), the kernel and wrapper source is produced by ``generate_code``.
    It is written to ``code_cache_dir()`` and imported as a module. Later calls
    with the same key reuse the loaded ``_index_add_wrapper``.
    """

    def __init__(self):
        # The pid is part of the generated file/module name, presumably so
        # that concurrent processes do not clobber each other's files.
        self.pid = os.getpid()
        # rank key (str) -> loaded `_index_add_wrapper` function
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Forward ``args``/``kwargs`` to the wrapper for their rank, building it on a miss."""
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, write and import the wrapper module for ``key``; return its wrapper."""
        # `importlib.util` is a submodule; `import importlib` alone does not
        # guarantee it has been loaded, so import it explicitly.
        import importlib.util

        code = IndentedBuffer()
        code = generate_code(
            args,
            "_index_add_wrapper",
            "_index_add_jit_function",
            code,
        )

        file_path = code_cache_dir() / f"index_add_rank_{key}_pid_{self.pid}.py"
        with open(file_path, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}_pid_{self.pid}",
            str(file_path),
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_index_add_wrapper")

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor arguments.

        Raises:
            ValueError: if none of ``args`` is a tensor.
        """
        return max(item.ndim for item in args if torch.is_tensor(item))


_index_add_func = IndexAddFunction()

223 

224 

def index_add(inp, dim, index, src, alpha=1):
    """Return a copy of ``inp`` with ``alpha * src`` accumulated at ``index`` along ``dim``.

    For every position ``k`` along ``dim`` of ``src``, that slice of ``src`` times
    ``alpha`` is added to the output slice at ``index[k]`` along ``dim``. The
    generated kernel uses atomic adds, so repeated index values accumulate.
    ``inp`` itself is not modified.

    Args:
        inp: destination values (copied, never written).
        dim: dimension to index along; negative values count from the end.
        index: integer tensor with ``src.size(dim)`` elements, each in
            ``[0, inp.size(dim))``. NOTE(review): assumed 1-D (or 0-d), as in
            ``torch.Tensor.index_add``.
        src: values to add. Must have the same ndim as ``inp`` and the same size
            in every dimension except ``dim``.
        alpha: scalar multiplier applied to ``src``.

    Returns:
        A new, contiguous tensor with ``inp``'s shape and dtype.

    Raises:
        AssertionError: if ``dim``, the shapes or the index values are invalid
            (these checks are skipped under ``python -O``).
    """
    logger.debug("GEMS_ASCEND INDEX ADD")
    # Validate `dim` first: every following check indexes with it.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    dim %= inp.ndim
    # `all(...)` is required: asserting a bare generator object always passes.
    assert all(
        inp.size(i) == src.size(i) for i in range(inp.ndim) if i != dim
    ), "src.size(d) == self.size(d) for all dimensions d != dim"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    assert bool(
        ((index >= 0) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"

    # The generated kernel maps src's memory offset to (outer, dim, inner)
    # coordinates assuming row-major contiguous layouts. It also reads `index`
    # as a flat array. Normalize all three layouts so strided or permuted
    # inputs are handled correctly.
    out = inp.clone(memory_format=torch.contiguous_format)
    src = src.contiguous()
    index = index.contiguous()

    N = src.numel()
    if N == 0:
        # Nothing to add; skip launching an empty grid.
        return out

    inp_stride_dim = out.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = out.size(dim)
    delta = inp_shape_dim - src_shape_dim

    _index_add_func(
        out,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        out.numel(),
        alpha,
    )
    return out