Coverage for src/flag_gems/runtime/backend/_cambricon/ops/gather.py: 0%

157 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import restride_dim

from .scatter import scatter

13 

# Child logger under the shared "flag_gems" root logger; the lstrip(".")
# presumably guards against a leading dot in __name__ — TODO confirm.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

17def generate_imports(code: IndentedBuffer) -> IndentedBuffer: 

18 code.writeline("import torch") 

19 code.writeline("import triton") 

20 code.writeline("import triton.language as tl") 

21 code.newline() 

22 code.writeline("from flag_gems.utils import libentry, libtuner") 

23 code.writeline("from flag_gems import runtime") 

24 code.writeline("from flag_gems.utils import triton_lang_extension as tle") 

25 

26 code.newline() 

27 code.newline() 

28 return code 

29 

30 

def generate_gather_kernel(
    dim: int,
    large_input: bool,
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit a Triton gather kernel specialized for (dim, rank, large_input).

    The generated kernel maps each flat output offset back to its
    multi-dimensional coordinates, accumulates the input offset from the
    per-dimension strides, and replaces the ``dim`` contribution with the
    loaded index value times ``stride_dim``.

    ``large_input`` selects 64-bit offset/index arithmetic; otherwise 32-bit
    is used (the caller decides the threshold).
    """
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline(
        '@libtuner(configs=runtime.get_tuned_config("gather"), key=["N"], strategy=["log"])'
    )
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # NOTE(review): for rank == 0 the pointer/stride/shape parameters are
        # omitted while the wrapper still passes the tensors — presumably
        # rank 0 never reaches here; verify against callers.
        if rank > 0:
            code.writeline("inp,")
            code.writeline("out,")
            code.writeline("index,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for index")

            shape_args = ", ".join(f"index_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for index")

        code.writeline("dim,")
        code.writeline("stride_dim,")
        code.writeline("N,")
        code.writeline("BLOCK_SIZE: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        # Each program handles one BLOCK_SIZE-sized slice of the N elements.
        code.writeline("pid = tl.program_id(0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # 1. Calculate inp_offsets and idx_offsets
        if large_input:
            code.writeline("inp_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)")
        else:
            code.writeline("inp_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)")
        # index/out are contiguous (the wrapper's caller makes them so), so
        # the flat offset addresses them directly.
        code.writeline("index_offsets = offsets")

        # 2. snippets: peel coordinates off `offsets` from the innermost
        # dimension outward, skipping the gather dimension's contribution.
        for i in range(rank - 1, -1, -1):
            if not (dim == 0 and i == 0):
                code.writeline(f"mod = offsets % index_shape_{i}")

            if i != dim:
                # will be corrected by adding cur_index*stride_dim
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
            if i != 0:
                code.writeline(f"offsets //= index_shape_{i}")

        # Use offsets to gather
        if large_input:
            code.writeline(
                "cur_index = tl.load(index + index_offsets, mask=mask, other=0)"
            )
        else:
            code.writeline(
                "cur_index = tl.load(index + index_offsets, mask=mask, other=0).to(tl.int32)"
            )

        # Substitute the gathered index for the skipped `dim` contribution.
        code.writeline("inp_offsets += cur_index * stride_dim")

        code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
        code.writeline("tl.store(out + index_offsets, cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code

113 

114 

def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper."""
    # Order matters: it must match both the generated kernel launch and the
    # positional call made by GatherFunction.__call__.
    return ", ".join(
        ("inp_strided", "out", "index", "dim", "stride_dim", "N")
    )

127 

128 

def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the host-side wrapper that unpacks strides/shapes and launches
    the generated gather kernel on a 1-D grid sized by the tuned BLOCK_SIZE.
    """
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # The kernel takes strides/shapes as flat scalar arguments, so the
        # wrapper reads them off the tensors at call time.
        code.writeline("inp_strides = inp_strided.stride()")
        code.writeline("index_strides = index.stride()")
        code.writeline("index_shapes = list(index.shape)")

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK_SIZE"]),')
            code.writeline(")")

        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)

        with code.indent():
            code.writeline("inp_strided, out, index, ")
            # Scalar stride/shape args are only emitted for rank > 0,
            # matching the kernel signature in generate_gather_kernel.
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")

            code.writeline("dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
            code.writeline(")")
        code.writeline("return out")

    return code

172 

173 

def generate_code(
    dim: int,
    large_input: bool,
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the full generated module: imports, kernel, then wrapper.

    ``inputs`` is the positional argument tuple forwarded by GatherFunction:
    (inp_strided, out, index, dim, stride_dim, N); only the index tensor at
    position 2 is inspected here, to fix the rank the kernel is built for.
    """
    index_rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_gather_kernel(dim, large_input, index_rank, kernel_name, code)
    return generate_gather_wrapper(index_rank, wrapper_name, kernel_name, code)

190 

191 

class GatherFunction:
    """Generate-on-demand dispatcher for gather kernels.

    Each distinct (max tensor rank, index rank, dim, large_input) combination
    gets its own generated module, written to the code cache directory and
    imported once; the resulting wrapper is memoized in ``self.overloads``.
    """

    def __init__(self):
        # pid is baked into the generated file/module names so concurrent
        # processes never clobber each other's cache files.
        self.pid = os.getpid()
        # Fix: annotated as Dict, not Mapping — Mapping is the read-only ABC
        # and this cache is mutated below.
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Fetch (or generate, write, and import) the specialized wrapper for
        this call's shape configuration, then invoke it with ``args``.

        Required kwargs: ``rank``, ``dim``, ``large_input`` (cache-key only;
        the positional args are forwarded unchanged to the wrapper).
        """
        rank = kwargs["rank"]
        dim = kwargs["dim"]
        large_input = kwargs["large_input"]

        key = f"{self.arg_key(*args)}_{rank}_{dim}_{large_input}"
        overload = self.overloads.get(key)
        if overload is None:
            code = IndentedBuffer()
            code = generate_code(
                dim,
                large_input,
                args,
                "_gather_wrapper",
                "_gather_jit_function",
                code,
            )

            file_name = f"gather_rank_{key}_pid_{self.pid}.py"
            file_path = code_cache_dir() / file_name

            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # Import the freshly written module straight from its file path.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            # Fix: plain attribute access instead of getattr with a constant
            # string (ruff B009).
            overload = m._gather_wrapper
            self.overloads[key] = overload

        return overload(*args)

    def arg_key(self, *args):
        """Return the maximum rank among tensor arguments (cache-key part)."""
        tensors = [item for item in args if torch.is_tensor(item)]
        return max(item.ndim for item in tensors)

238 

239 

# Module-level singleton: caches generated gather overloads for this process.
_gather_func = GatherFunction()

241 

242 

def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values of ``inp`` along ``dim`` at positions given by ``index``.

    Mirrors ``torch.gather``'s signature; the result has ``index``'s shape.
    ``sparse_grad`` is accepted for signature compatibility but unused here.
    """
    logger.debug("GEMS_CAMBRICON GATHER")
    inp = inp.contiguous()
    index = index.contiguous()
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    # Fix: ``.contiguous()`` returns a *copy* when ``out`` is non-contiguous,
    # so the kernel previously filled the copy and the caller's ``out`` was
    # silently left untouched. Compute into the contiguous buffer and copy
    # back at the end when they differ.
    out_c = out.contiguous()
    stride_dim = inp.stride(dim)

    # View of ``inp`` restrided to walk index-shaped coordinates along every
    # dimension except ``dim``.
    inp_strided = restride_dim(inp, dim, index.shape)
    N = index.numel()

    # 64-bit offsets only when the input spans more than 2**31 bytes.
    large_input = inp.numel() * inp.element_size() > 2**31
    rank = len(index.shape)

    # <rank>_<dim>_<large_input> is the key of overloads
    # large_input is only for key
    _gather_func(
        inp_strided,
        out_c,
        index,
        dim,
        stride_dim,
        N,
        large_input=large_input,
        dim=dim,
        rank=rank,
    )
    if out_c is not out:
        out.copy_(out_c)
    return out

272 

273 

def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of gather: scatter-add ``grad`` back into the input's shape."""
    logger.debug("GEMS_CAMBRICON GATHER BACKWARD")
    # new_zeros keeps the incoming gradient's dtype and device.
    grad_input = grad.new_zeros(self.shape)
    return scatter(grad_input, dim, index, grad, reduce="add")

277 return scatter(result, dim, index, grad, reduce="add")