Coverage for src/flag_gems/ops/vdot.py: 53%

140 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

14@triton.jit 

15def compute_vdot( 

16 inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj 

17): 

18 # # Given inp storage: [inp_real, inp_imag], other: [other_real, other_imag] 

19 

20 # # Case 1: inp_is_conj = False, other_is_conj = False 

21 # out_real = inp_real * other_real + inp_imag * other_imag 

22 # out_imag = inp_real * other_imag - inp_imag * other_real 

23 

24 # # Case 2: inp_is_conj = True, other_is_conj = False 

25 # out_real = inp_real * other_real - inp_imag * other_imag 

26 # out_imag = inp_real * other_imag + inp_imag * other_real 

27 

28 # # Case 3: inp_is_conj = False, other_is_conj = True 

29 # out_real = inp_real * other_real - inp_imag * other_imag 

30 # out_imag = -inp_real * other_imag - inp_imag * other_real 

31 

32 # # Case 4: inp_is_conj = True, other_is_conj = True 

33 # out_real = inp_real * other_real + inp_imag * other_imag 

34 # out_imag = inp_real * other_imag - inp_imag * other_real 

35 if not inp_is_conj and not other_is_conj: # Case 1 

36 out_real = tl.sum(inp_real * other_real + inp_imag * other_imag) 

37 out_imag = tl.sum(inp_real * other_imag - inp_imag * other_real) 

38 elif inp_is_conj and not other_is_conj: # Case 2 

39 out_real = tl.sum(inp_real * other_real - inp_imag * other_imag) 

40 out_imag = tl.sum(inp_real * other_imag + inp_imag * other_real) 

41 elif not inp_is_conj and other_is_conj: # Case 3 

42 out_real = tl.sum(inp_real * other_real - inp_imag * other_imag) 

43 out_imag = tl.sum(-inp_real * other_imag - inp_imag * other_real) 

44 else: # Case 4 

45 out_real = tl.sum(inp_real * other_real + inp_imag * other_imag) 

46 out_imag = tl.sum(-inp_real * other_imag + inp_imag * other_real) 

47 

48 return out_real, out_imag 

49 

50 

51# support old version triton which do not support tl.split 

52@libentry() 

53@triton.jit() 

54def vdot_kernel_complex( 

55 inp_ptr, 

56 other_ptr, 

57 out_ptr, 

58 n_elements, 

59 inp_is_conj: tl.constexpr, 

60 other_is_conj: tl.constexpr, 

61 inp_stride: tl.constexpr, 

62 other_stride: tl.constexpr, 

63 BLOCK_SIZE: tl.constexpr, 

64): 

65 pid = tl.program_id(0) 

66 num_progs = tl.num_programs(0) 

67 

68 grid_stride = num_progs * BLOCK_SIZE 

69 

70 acc_real = tl.zeros([], dtype=tl.float32) 

71 acc_imag = tl.zeros([], dtype=tl.float32) 

72 

73 for current_start in range(0, n_elements // 2, grid_stride): 

74 complex_idx = current_start + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

75 mask = complex_idx < n_elements // 2 

76 

77 real_offset = complex_idx * 2 

78 

79 inp_real = tl.load(inp_ptr + real_offset * inp_stride, mask=mask, other=0.0) 

80 inp_imag = tl.load(inp_ptr + real_offset * inp_stride + 1, mask=mask, other=0.0) 

81 

82 other_real = tl.load( 

83 other_ptr + real_offset * other_stride, mask=mask, other=0.0 

84 ) 

85 other_imag = tl.load( 

86 other_ptr + real_offset * other_stride + 1, mask=mask, other=0.0 

87 ) 

88 

89 out_real, out_imag = compute_vdot( 

90 inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj 

91 ) 

92 acc_real += out_real 

93 acc_imag += out_imag 

94 

95 temp_offset = pid * 2 

96 tl.store(out_ptr + temp_offset, acc_real) 

97 tl.store(out_ptr + temp_offset + 1, acc_imag) 

98 

99 

100@libentry() 

101@triton.jit() 

102def reduce_kernel_complex(input_ptr, out_ptr, n_blocks, BLOCK_SIZE: tl.constexpr): 

103 pid = tl.program_id(0) 

104 base_offset = tl.arange(0, BLOCK_SIZE) 

105 mask = base_offset < n_blocks 

106 

107 inp_real = tl.load(input_ptr + base_offset * 2, mask=mask, other=0.0) 

108 inp_imag = tl.load(input_ptr + base_offset * 2 + 1, mask=mask, other=0.0) 

109 final_out_real = tl.sum(inp_real) 

110 final_out_imag = tl.sum(inp_imag) 

111 if pid == 0: 

112 tl.store(out_ptr, final_out_real) 

113 tl.store(out_ptr + 1, final_out_imag) 

114 

115 

116# only support real number 

117@libentry() 

118@triton.heuristics(runtime.get_heuristic_config("vdot")) 

119@triton.jit() 

120def dot_kernel( 

121 inp_ptr, 

122 other_ptr, 

123 out_ptr, 

124 n_elements, 

125 inp_stride: tl.constexpr, 

126 other_stride: tl.constexpr, 

127 BLOCK_SIZE: tl.constexpr, 

128): 

129 pid = tl.program_id(0) 

130 num_progs = tl.num_programs(0) 

131 grid_stride = num_progs * BLOCK_SIZE 

132 

133 acc = tl.zeros([], dtype=tl.float32) 

134 

135 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

136 

137 for current_start in range(0, n_elements, grid_stride): 

138 cur_offsets = current_start + offsets 

139 mask = cur_offsets < n_elements 

140 

141 inp = tl.load(inp_ptr + inp_stride * cur_offsets, mask=mask, other=0.0).to( 

142 tl.float32 

143 ) 

144 other = tl.load( 

145 other_ptr + other_stride * cur_offsets, mask=mask, other=0.0 

146 ).to(tl.float32) 

147 

148 acc += tl.sum(inp * other) 

149 

150 tl.store(out_ptr + pid, acc) 

151 

152 

153@libentry() 

154@triton.jit() 

155def reduce_kernel( 

156 partial_sums_ptr, 

157 output_ptr, 

158 n_blocks, 

159 BLOCK_SIZE: tl.constexpr, 

160): 

161 offset = tl.arange(0, BLOCK_SIZE) 

162 mask = offset < n_blocks 

163 

164 partial_sums = tl.load(partial_sums_ptr + offset, mask=mask, other=0.0) 

165 final_sum = tl.sum(partial_sums) 

166 

167 if tl.program_id(0) == 0: 

168 tl.store(output_ptr, final_sum) 

169 

170 

171@libentry() 

172@triton.heuristics(runtime.get_heuristic_config("vdot")) 

173@triton.jit() 

174def dot_kernel_fp32( 

175 inp_ptr, 

176 other_ptr, 

177 out_ptr, 

178 n_elements, 

179 inp_stride: tl.constexpr, 

180 other_stride: tl.constexpr, 

181 BLOCK_SIZE: tl.constexpr, 

182): 

183 pid = tl.program_id(0) 

184 offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

185 mask = offset < n_elements 

186 

187 inp = tl.load(inp_ptr + inp_stride * offset, mask=mask) 

188 other = tl.load(other_ptr + other_stride * offset, mask=mask) 

189 

190 out = tl.sum(inp * other) 

191 tl.atomic_add(out_ptr, out) 

192 

193 

194def vdot(input: Tensor, other: Tensor): 

195 logger.debug("GEMS VDOT") 

196 

197 assert ( 

198 input.dtype == other.dtype 

199 ), f"Input tensors must have the same dtype. Got {input.dtype} and {other.dtype}." 

200 assert ( 

201 input.ndim == 1 and other.ndim == 1 

202 ), f"Input tensors must be 1D. Got {input.ndim}D and {other.ndim}D." 

203 assert ( 

204 input.size() == other.size() 

205 ), f"Input tensors must have the same size. Got {input.size()} and {other.size()}." 

206 

207 inp = input 

208 inp_stride = inp.stride()[0] 

209 other_stride = other.stride()[0] 

210 

211 if inp.is_complex(): 

212 inp_is_conj = False 

213 other_is_conj = False 

214 

215 if inp.is_conj(): 

216 inp_is_conj = True 

217 inp = inp.conj() 

218 

219 if other.is_conj(): 

220 other_is_conj = True 

221 other = other.conj() 

222 

223 inp_real = torch.view_as_real(inp) 

224 other_real = torch.view_as_real(other) 

225 

226 n_elements = inp_real.numel() 

227 n_complex = inp.numel() 

228 

229 block_size = runtime.get_heuristic_config("vdot")["BLOCK_SIZE"]( 

230 {"n_elements": n_elements} 

231 ) 

232 num_blocks = triton.cdiv(n_complex, block_size) 

233 

234 grid_size = min(num_blocks, 1024) 

235 

236 partial_real_sums = torch.empty( 

237 grid_size, dtype=inp_real.dtype, device=inp.device 

238 ) 

239 grid = (grid_size,) 

240 vdot_kernel_complex[grid]( 

241 inp_real, 

242 other_real, 

243 partial_real_sums, 

244 n_elements=n_elements, 

245 inp_is_conj=inp_is_conj, 

246 other_is_conj=other_is_conj, 

247 inp_stride=inp_stride, 

248 other_stride=other_stride, 

249 BLOCK_SIZE=block_size, 

250 ) 

251 output_real = torch.empty(2, dtype=inp_real.dtype, device=inp.device) 

252 reduce_kernel_complex[(1,)]( 

253 partial_real_sums, 

254 output_real, 

255 grid_size, 

256 BLOCK_SIZE=triton.next_power_of_2(grid_size), 

257 ) 

258 return torch.view_as_complex(output_real) 

259 elif inp.dtype == torch.float32: 

260 output = torch.zeros([], dtype=torch.float32, device=inp.device) 

261 n_elements = inp.numel() 

262 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

263 dot_kernel_fp32[grid]( 

264 inp, 

265 other, 

266 output, 

267 n_elements=n_elements, 

268 inp_stride=inp_stride, 

269 other_stride=other_stride, 

270 ) 

271 return output 

272 else: 

273 n_elements = inp.numel() 

274 block_size = runtime.get_heuristic_config("vdot")["BLOCK_SIZE"]( 

275 {"n_elements": n_elements} 

276 ) 

277 

278 num_blocks = triton.cdiv(n_elements, block_size) 

279 grid_size = min(num_blocks, 1024) 

280 

281 grid = (num_blocks,) 

282 partial_sums = torch.empty(grid_size, dtype=torch.float32, device=inp.device) 

283 dot_kernel[grid]( 

284 inp, 

285 other, 

286 partial_sums, 

287 n_elements=n_elements, 

288 inp_stride=inp_stride, 

289 other_stride=other_stride, 

290 BLOCK_SIZE=block_size, 

291 ) 

292 output = torch.empty([], dtype=input.dtype, device=inp.device) 

293 reduce_bs = min(triton.next_power_of_2(grid_size), 1024) 

294 reduce_kernel[(1,)]( 

295 partial_sums, 

296 output, 

297 num_blocks, 

298 BLOCK_SIZE=reduce_bs, 

299 ) 

300 return output