Coverage for src/flag_gems/ops/gather.py: 98%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.ops.scatter import scatter_ 

9from flag_gems.utils.code_cache import code_cache_dir 

10from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

11from flag_gems.utils.shape_utils import restride_dim 

12 

13logger = logging.getLogger(__name__) 

14 

15 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import header shared by every generated gather module.

    Args:
        code: buffer the import lines are appended to.

    Returns:
        The same buffer, for chaining.
    """
    for header_line in (
        "import torch",
        "import triton",
        "import triton.language as tl",
    ):
        code.writeline(header_line)
    code.newline()
    for gems_line in (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils import triton_lang_extension as tle",
    ):
        code.writeline(gems_line)

    code.newline()
    code.newline()
    return code

28 

29 

def generate_gather_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit a rank-specialized Triton gather kernel into *code*.

    The generated kernel walks the flattened ``index`` tensor in blocks of
    ``BLOCK_SIZE_N``: it decomposes each linear offset into per-dimension
    indices, loads the gather coordinate from ``index``, converts it into an
    input offset via ``dim_stride``, loads from ``inp`` and stores to ``out``.

    Args:
        rank: tensor dimensionality; one shape and one stride parameter per
            dimension is generated for each of inp/index/out.
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the kernel source is appended to.

    Returns:
        The same *code* buffer, for chaining.
    """
    # make the inlined function visible in the context
    code.newline()

    code.writeline("@libentry()")
    code.writeline("@triton.heuristics({'BLOCK_SIZE_N': lambda args: 512})")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        args = [
            "inp, ",
            "index, ",
            "out, ",
        ]
        # Fix: emit "inp_shape{i}, " with a trailing space, consistent with
        # every other generated parameter below (was "inp_shape{i}," only).
        args += [f"inp_shape{i}, " for i in range(rank)]
        args += [f"index_shape{i}, " for i in range(rank)]
        args += [f"out_shape{i}, " for i in range(rank)]
        args += [f"inp_stride{i}, " for i in range(rank)]
        args += [f"index_stride{i}, " for i in range(rank)]
        args += [f"out_stride{i}, " for i in range(rank)]
        args += ["dim, ", "dim_stride, ", "N, ", "BLOCK_SIZE_N: tl.constexpr, "]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline(
            "offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)"
        )
        code.newline()
        # Decompose the linear offset into per-dimension indices of the index
        # tensor, innermost dimension first.
        code.writeline("cur_offset = offset")
        for i in range(rank - 1, -1, -1):
            code.writeline(f"index_idx{i} = cur_offset % index_shape{i}")
            code.writeline(f"cur_offset = cur_offset // index_shape{i}")
        code.newline()
        comp = [f"index_idx{i} * index_stride{i}" for i in range(rank)]
        code.writeline(f"index_offset = {' + '.join(comp)}")
        code.writeline("mask = offset < N")
        code.writeline("cur_index = tl.load(index + index_offset, mask=mask, other=0)")
        code.newline()
        # NOTE(review): the caller (see gather()) passes inp restrided via
        # restride_dim(inp, dim, index.shape), presumably neutralizing the
        # stride along `dim` so the true dim offset comes from dim_stride.
        comp = [f"index_idx{i} * inp_stride{i}" for i in range(rank)]
        code.writeline(f"inp_offset = {' + '.join(comp)}")
        code.writeline("inp_offset += cur_index * dim_stride")
        code.writeline("cur_inp = tl.load(inp + inp_offset, mask=mask, other=0)")
        code.newline()
        comp = [f"index_idx{i} * out_stride{i}" for i in range(rank)]
        code.writeline(f"out_offset = {' + '.join(comp)}")
        code.writeline("tl.store(out + out_offset, value=cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code

86 

87 

def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the host-side Python wrapper that launches the gather kernel.

    The wrapper unpacks shape/stride metadata of inp/index/out and forwards
    one scalar per dimension per tensor to the rank-specialized kernel.

    Args:
        rank: tensor dimensionality the wrapper is specialized for.
        wrapper_name: name of the generated Python entry point.
        kernel_name: name of the kernel the wrapper launches.
        code: buffer the wrapper source is appended to.

    Returns:
        The same *code* buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name}(inp, dim, index, out, dim_stride, N):")
    with code.indent():
        # Bind shape and stride tuples for each tensor in the generated scope.
        for tensor in ("inp", "index", "out"):
            code.writeline(f"{tensor}_shape = {tensor}.shape")
            code.writeline(f"{tensor}_stride = {tensor}.stride()")
        code.writeline("grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE_N']), )")
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            call_args = [
                "inp, ",
                "index, ",
                "out, ",
            ]
            # All shapes first, then all strides, matching the kernel signature.
            for attr in ("shape", "stride"):
                for tensor in ("inp", "index", "out"):
                    call_args += [f"{tensor}_{attr}[{i}], " for i in range(rank)]
            call_args += [
                "dim, ",
                "dim_stride, ",
                "N, ",
            ]
            code.writelines(call_args)
        code.writeline(")")
        code.writeline("return out")
    code.newline()
    code.newline()
    return code

127 

128 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate a complete gather module (imports, kernel, wrapper).

    Args:
        inputs: runtime arguments; only ``inputs[0].ndim`` is used, to pick
            the rank specialization.
        wrapper_name: name of the generated host-side entry point.
        kernel_name: name of the generated Triton kernel.
        code: buffer the module source is appended to.

    Returns:
        The same *code* buffer holding the full module source.
    """
    ndim = inputs[0].ndim

    code = generate_imports(code)
    code = generate_gather_kernel(ndim, kernel_name, code)
    return generate_gather_wrapper(ndim, wrapper_name, kernel_name, code)

141 

142 

class GatherFunction:
    """Dispatcher that JIT-generates and caches one gather overload per rank.

    On first call with a given input rank, the matching module source is
    generated, written to the on-disk code cache, imported, and its wrapper
    memoized; later calls with the same rank reuse the cached callable.
    """

    def __init__(self):
        # pid recorded at construction time; overloads maps rank-key -> wrapper.
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = str(self.arg_key(*args))
        overload = self.overloads.get(key)
        if overload is None:
            buffer = IndentedBuffer()
            buffer = generate_code(
                args,
                "_gather_wrapper",
                "_gather_flaggems_jit_function",
                buffer,
            )

            file_path = code_cache_dir() / f"gather_rank_{key}.py"
            write_atomic(file_path, buffer.getvalue())

            # Import the freshly written module and grab its wrapper function.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            overload = module._gather_wrapper
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        # Specialization key: only the rank of the first tensor matters.
        return args[0].ndim

180 

181 

# Module-level singleton; holds the per-rank cache of compiled gather overloads.
_gather_func = GatherFunction()

183 

184 

def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values from *inp* along *dim* at positions given by *index*.

    Args:
        inp: source tensor.
        dim: dimension along which to gather.
        index: tensor of gather coordinates; the result has its shape.
        out: optional destination; allocated like *index* (with inp's dtype
            and device) when None.
        sparse_grad: accepted for interface compatibility; not used here.

    Returns:
        The destination tensor holding the gathered values.
    """
    logger.debug("GEMS GATHER")
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    stride_along_dim = inp.stride(dim)
    # Restride inp so the kernel addresses dim purely through dim_stride.
    restrided_inp = restride_dim(inp, dim, index.shape)
    total = index.numel()
    _gather_func(restrided_inp, dim, index, out, stride_along_dim, total)
    return out

194 

195 

def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of gather: scatter-add *grad* into zeros shaped like *self*."""
    logger.debug("GEMS GATHER BACKWARD")
    return scatter_(grad.new_zeros(self.shape), dim, index, grad, reduce="add")