Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/gather.py: 0%

189 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import restride_dim

from .scatter import scatter_

13 

# Child logger under the package-wide "flag_gems" logger; lstrip(".") drops any
# leading dots so getChild() receives a clean dotted suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated gather module into *code*.

    Emits the third-party imports first, then the flag_gems helpers the
    generated kernel and wrapper rely on, followed by two blank lines.
    """
    third_party = (
        "import torch",
        "import triton",
        "import triton.language as tl",
        "import builtins",
    )
    for line in third_party:
        code.writeline(line)
    code.newline()

    flag_gems_helpers = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for line in flag_gems_helpers:
        code.writeline(line)

    code.newline()
    code.newline()
    return code

30 

31 

def generate_gather_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit a rank-specialized Triton gather kernel named *kernel_name* into *code*.

    The generated kernel iterates a flattened 2D (M, N) view of the index
    tensor, decomposes each linear position into per-dimension coordinates
    (via the unrolled modulo/floor-divide sequence emitted below), and
    gathers from `inp` at the loaded index value along the gather dim.

    Args:
        rank: number of dimensions of the index tensor; controls how many
            stride/shape constexpr parameters the kernel signature takes.
        kernel_name: name given to the emitted @triton.jit function.
        code: buffer the source is written into.

    Returns:
        The same *code* buffer, for chaining.
    """
    # make the inlined function visible in the context
    code.newline()

    # the autotune function
    # NOTE(review): cfggen() is emitted but the @triton.autotune decorator
    # below is commented out, so these configs are currently dead code in
    # the generated module.
    code.writeline("def cfggen():")
    with code.indent():
        code.writeline("block_m = [1, 2, 4, 8]")
        code.writeline("block_n = [256, 512, 1024, 2048]")
        code.writeline("configs = [")
        with code.indent():
            code.writeline('triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)')
            code.writeline("for m in block_m")
            code.writeline("for n in block_n")
        code.writeline("]")
        code.writeline("return configs")

    code.newline()
    code.newline()

    # Heuristics replace autotuning: BLOCK_M targets ~12 row-tiles,
    # BLOCK_N covers N up to a 4096 cap.
    code.writeline("def heur_block_m(args):")
    with code.indent():
        code.writeline('return triton.next_power_of_2(triton.cdiv(args["M"], 12))')

    code.newline()

    code.writeline("def heur_block_n(args):")
    with code.indent():
        code.writeline('return builtins.min(triton.next_power_of_2(args["N"]), 4096)')

    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    # code.writeline('@triton.autotune(configs=cfggen(), key=["M", "N"])')
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_M": heur_block_m,')
            code.writeline('"BLOCK_N": heur_block_n,')
        code.writeline("},")
    code.writeline(")")
    code.writeline("@triton.jit")

    # signature
    # NOTE(review): for rank == 0 the generated signature omits inp/out/index
    # while the wrapper still passes them — presumably rank is always >= 1
    # in practice; confirm before relying on scalar-index inputs.
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("inp,")
            code.writeline("out,")
            code.writeline("index,")

            stride_args = ", ".join(
                f"inp_stride_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(
                f"index_stride_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for index")

            shape_args = ", ".join(
                f"index_shape_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{shape_args}, # shape for index")

        code.writeline("dim: tl.constexpr,")
        code.writeline("stride_dim: tl.constexpr,")
        code.writeline("M: tl.constexpr,")
        code.writeline("N: tl.constexpr,")
        code.writeline("BLOCK_M: tl.constexpr,")
        code.writeline("BLOCK_N: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        # Each program instance handles one (BLOCK_M, BLOCK_N) tile of the
        # flattened (M, N) iteration space.
        code.writeline("pid_x = tle.program_id(0)")
        code.writeline("pid_y = tle.program_id(1)")
        code.writeline(
            "rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]"
        )
        code.writeline(
            "cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]"
        )
        code.writeline("rows_mask = rows_offsets < M")
        code.writeline("cols_mask = cols_offsets < N")

        # NOTE(review): the generated "offsets" variable below is never
        # referenced again in the kernel body — apparently dead code.
        code.writeline("offsets = (rows_offsets * N + cols_offsets).to(tl.int64)")
        code.writeline("mask = rows_mask & cols_mask")

        # 1. Calculate inp_offsets and idx_offsets
        code.writeline("inp_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)")
        code.writeline("idx_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)")
        code.writeline("cur_idx = rows_offsets * N + cols_offsets")

        # 2. snippets
        # Unrolled per-dimension decomposition: peel the coordinate of each
        # dimension off the linear index and accumulate strided offsets for
        # both inp and index. The final divide is skipped (quotient unused).
        for i in range(rank):
            code.writeline(f"mod = cur_idx % index_shape_{i}")
            code.writeline(f"inp_offsets += mod * inp_stride_{i}")
            code.writeline(f"idx_offsets += mod * index_stride_{i}")
            if i != (rank - 1):
                code.writeline(f"cur_idx //= index_shape_{i}")

        # Use offsets to gather: load the index value, shift the input
        # offset along the gather dim, then load and store.
        code.writeline("cur_index = tl.load(index + idx_offsets, mask=mask, other=0)")
        code.writeline("inp_offsets += cur_index * stride_dim")
        code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
        code.writeline("tl.store(out + idx_offsets, cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code

151 

152 

def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list for the generated wrapper.

    Order matters: it must match both the generated wrapper's positional
    call into the kernel and the positional call made by
    ``GatherFunction.__call__`` (inp_strided, out, index, dim, stride_dim,
    M, N).
    """
    # A fixed literal list is clearer than seven sequential append() calls.
    parameters: List[str] = [
        "inp_strided",
        "out",
        "index",
        "dim",
        "stride_dim",
        "M",
        "N",
    ]
    return ", ".join(parameters)

166 

167 

def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python wrapper that launches the generated gather kernel.

    The wrapper looks up strides/shapes at call time, builds a 2D launch
    grid over (M, N) tiles, and forwards everything positionally to the
    kernel named *kernel_name*.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")

    with code.indent():
        for line in (
            "inp_strides = inp_strided.stride()",
            "index_strides = index.stride()",
            "index_shapes = list(index.shape)",
        ):
            code.writeline(line)

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(M, meta["BLOCK_M"]),')
            code.writeline('triton.cdiv(N, meta["BLOCK_N"])')
        code.writeline(")")

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline("inp_strided, out, index, ")
            if rank > 0:
                # One argument line per sequence: inp strides, index strides,
                # then index shapes, each expanded element-wise.
                for seq in ("inp_strides", "index_strides", "index_shapes"):
                    expanded = ", ".join(f"{seq}[{i}]" for i in range(rank))
                    code.writeline(f"{expanded},")

            for line in ("dim,", "stride_dim,", "M,", "N,"):
                code.writeline(line)
        code.writeline(")")
        code.writeline("return out")

    return code

213 

214 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the full generated module: imports, kernel, then wrapper.

    `inputs` follows the wrapper calling convention
    (inp_strided, out, index, dim, stride_dim, M, N); only the rank of the
    index tensor (inputs[2]) influences code generation.
    """
    rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_gather_kernel(rank, kernel_name, code)
    return generate_gather_wrapper(rank, wrapper_name, kernel_name, code)

229 

230 

class GatherFunction:
    """Rank-specialized gather dispatcher.

    Generates a Triton gather kernel + wrapper specialized to the maximum
    tensor rank among the call arguments, writes it to the code cache,
    imports it, and caches the resulting callable per process.
    """

    def __init__(self):
        self.pid = os.getpid()
        # Cache of generated wrappers keyed by str(rank). Must be a mutable
        # dict (it is written in __call__), so the annotation is Dict, not
        # the read-only typing.Mapping.
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to (generating on first use) the rank-specialized wrapper."""
        key = str(self.arg_key(*args))
        overload = self.overloads.get(key)
        if overload is None:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_gather_wrapper",
                "_gather_jit_function",
                code,
            )

            # pid in the filename avoids clashes between concurrent processes
            # sharing one cache directory.
            file_path = cache_dir() / f"gather_rank_{key}_pid_{self.pid}.py"
            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # Import the freshly generated module and grab its wrapper.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                file_path,
            )
            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = m._gather_wrapper
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        """Cache key: the maximum rank among the tensor arguments."""
        tensors = [item for item in args if torch.is_tensor(item)]
        return max(item.ndim for item in tensors)

271 

272 

# Process-wide singleton that caches generated gather overloads by rank.
_gather_func = GatherFunction()

274 

275 

def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values of `inp` along `dim` at positions given by `index`.

    Mirrors ``torch.gather``. `sparse_grad` is accepted for signature
    compatibility but is not used here.

    Args:
        inp: source tensor.
        dim: dimension to gather along.
        index: integer tensor of positions; determines the output shape.
        out: optional destination tensor; filled and returned when given.
        sparse_grad: unused.

    Returns:
        Tensor with the same shape as `index` holding the gathered values.
    """
    logger.debug("GEMS GATHER")
    inp = inp.contiguous()
    index = index.contiguous()
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    if index.numel() == 0:
        # Empty gather: nothing to launch — and `M = numel // N` below
        # would divide by zero for a zero-sized trailing dimension.
        return out
    # The kernel assumes a contiguous destination; remember whether
    # contiguous() had to copy so results can be written back to `out`.
    out_contig = out.contiguous()
    stride_dim = inp.stride(dim)

    # View of `inp` restrided to index.shape — presumably with the gather
    # dim's stride neutralized so the kernel adds index * stride_dim itself;
    # TODO confirm against restride_dim.
    inp_strided = restride_dim(inp, dim, index.shape)
    # Flatten the iteration space to 2D: N = innermost extent, M = the rest.
    N = index.shape[-1]
    M = index.numel() // N

    _gather_func(inp_strided, out_contig, index, dim, stride_dim, M, N)
    if out_contig is not out:
        # contiguous() made a copy above; honor the out= contract by
        # propagating the results into the caller's tensor.
        out.copy_(out_contig)
    return out

292 

293 

def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of gather: scatter-add `grad` back to the input's positions.

    Returns a zero tensor shaped like `self` with `grad` accumulated at the
    locations originally gathered from.
    """
    logger.debug("GEMS GATHER BACKWARD")
    grad_input = grad.new_zeros(self.shape)
    return scatter_(grad_input, dim, index, grad, reduce="add")

297 return scatter_(result, dim, index, grad, reduce="add")