Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/vdot.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8# from flag_gems import runtime 

9from flag_gems.utils import libentry 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

14@triton.jit 

15def compute_vdot( 

16 inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj 

17): 

18 # # Given inp storage: [inp_real, inp_imag], other: [other_real, other_imag] 

19 

20 # # Case 1: inp_is_conj = False, other_is_conj = False 

21 # out_real = inp_real * other_real + inp_imag * other_imag 

22 # out_imag = inp_real * other_imag - inp_imag * other_real 

23 

24 # # Case 2: inp_is_conj = True, other_is_conj = False 

25 # out_real = inp_real * other_real - inp_imag * other_imag 

26 # out_imag = inp_real * other_imag + inp_imag * other_real 

27 

28 # # Case 3: inp_is_conj = False, other_is_conj = True 

29 # out_real = inp_real * other_real - inp_imag * other_imag 

30 # out_imag = -inp_real * other_imag - inp_imag * other_real 

31 

32 # # Case 4: inp_is_conj = True, other_is_conj = True 

33 # out_real = inp_real * other_real + inp_imag * other_imag 

34 # out_imag = inp_real * other_imag - inp_imag * other_real 

35 if not inp_is_conj and not other_is_conj: # Case 1 

36 out_real = tl.sum(inp_real * other_real + inp_imag * other_imag) 

37 out_imag = tl.sum(inp_real * other_imag - inp_imag * other_real) 

38 elif inp_is_conj and not other_is_conj: # Case 2 

39 out_real = tl.sum(inp_real * other_real - inp_imag * other_imag) 

40 out_imag = tl.sum(inp_real * other_imag + inp_imag * other_real) 

41 elif not inp_is_conj and other_is_conj: # Case 3 

42 out_real = tl.sum(inp_real * other_real - inp_imag * other_imag) 

43 out_imag = tl.sum(-inp_real * other_imag - inp_imag * other_real) 

44 else: # Case 4 

45 out_real = tl.sum(inp_real * other_real + inp_imag * other_imag) 

46 out_imag = tl.sum(-inp_real * other_imag + inp_imag * other_real) 

47 

48 return out_real, out_imag 

49 

50 

51def vdot_kernel_heur_block_size(args): 

52 if args["n_elements"] < 8192: 

53 return args["n_elements"] 

54 

55 return triton.next_power_of_2(triton.cdiv(args["n_elements"], 12)) 

56 

57 

58# support old version triton which do not support tl.split 

59@libentry() 

60# @triton.heuristics(runtime.get_heuristic_config("vdot")) 

61@triton.heuristics( 

62 values={ 

63 "BLOCK_SIZE": vdot_kernel_heur_block_size, 

64 }, 

65) 

66@triton.jit() 

67def vdot_kernel_complex( 

68 inp_ptr, 

69 other_ptr, 

70 out_ptr, 

71 n_elements: tl.constexpr, 

72 inp_is_conj: tl.constexpr, 

73 other_is_conj: tl.constexpr, 

74 inp_stride: tl.constexpr, 

75 other_stride: tl.constexpr, 

76 BLOCK_SIZE: tl.constexpr, 

77): 

78 pid = tl.program_id(0) 

79 

80 base_offset = 2 * pid * BLOCK_SIZE + 2 * tl.arange(0, BLOCK_SIZE) + tl.arange(0, 1) 

81 

82 inp_real_offset = inp_stride * base_offset 

83 inp_imag_offset = inp_real_offset + 1 

84 

85 other_real_offset = other_stride * base_offset 

86 other_imag_offset = other_real_offset + 1 

87 

88 mask = base_offset < n_elements 

89 

90 inp_real = tl.load(inp_ptr + inp_real_offset, mask=mask) 

91 inp_imag = tl.load(inp_ptr + inp_imag_offset, mask=mask) 

92 

93 other_real = tl.load(other_ptr + other_real_offset, mask=mask) 

94 other_imag = tl.load(other_ptr + other_imag_offset, mask=mask) 

95 

96 inp_real = tl.where(mask, inp_real, 0.0) 

97 inp_imag = tl.where(mask, inp_imag, 0.0) 

98 other_real = tl.where(mask, other_real, 0.0) 

99 other_imag = tl.where(mask, other_imag, 0.0) 

100 

101 # Compute based on conjugate flags 

102 out_real, out_imag = compute_vdot( 

103 inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj 

104 ) 

105 

106 tl.store(out_ptr, out_real) 

107 tl.store(out_ptr + 1, out_imag) 

108 

109 

110def dot_kernel_heur_block_size(args): 

111 if args["n_elements"] % 2 != 0: 

112 return triton.next_power_of_2(args["n_elements"]) 

113 

114 if args["n_elements"] < 8192: 

115 return args["n_elements"] 

116 

117 return triton.next_power_of_2(triton.cdiv(args["n_elements"], 12)) 

118 

119 

120# only support real number 

121@libentry() 

122# @triton.heuristics(runtime.get_heuristic_config("vdot")) 

123@triton.heuristics( 

124 values={ 

125 "BLOCK_SIZE": dot_kernel_heur_block_size, 

126 }, 

127) 

128@triton.jit() 

129def dot_kernel( 

130 inp_ptr, 

131 other_ptr, 

132 out_ptr, 

133 n_elements: tl.constexpr, 

134 inp_stride: tl.constexpr, 

135 other_stride: tl.constexpr, 

136 BLOCK_SIZE: tl.constexpr, 

137): 

138 pid = tl.program_id(0) 

139 offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

140 mask = offset < n_elements 

141 

142 inp = tl.load(inp_ptr + inp_stride * offset, mask=mask).to(tl.float32) 

143 inp = tl.where(mask, inp, 0.0) 

144 other = tl.load(other_ptr + other_stride * offset, mask=mask).to(tl.float32) 

145 other = tl.where(mask, other, 0.0) 

146 

147 out = tl.sum(inp * other) 

148 tl.store(out_ptr, out) 

149 

150 

151def vdot(input: Tensor, other: Tensor): 

152 logger.debug("GEMS VDOT") 

153 

154 assert ( 

155 input.dtype == other.dtype 

156 ), f"Input tensors must have the same dtype. Got {input.dtype} and {other.dtype}." 

157 assert ( 

158 input.ndim == 1 and other.ndim == 1 

159 ), f"Input tensors must be 1D. Got {input.ndim}D and {other.ndim}D." 

160 assert ( 

161 input.size() == other.size() 

162 ), f"Input tensors must have the same size. Got {input.size()} and {other.size()}." 

163 

164 inp = input 

165 inp_stride = inp.stride()[0] 

166 other_stride = other.stride()[0] 

167 

168 if inp.is_complex(): 

169 inp_is_conj = False 

170 other_is_conj = False 

171 

172 if inp.is_conj(): 

173 inp_is_conj = True 

174 inp = inp.conj() 

175 

176 if other.is_conj(): 

177 other_is_conj = True 

178 other = other.conj() 

179 

180 inp_real = torch.view_as_real(inp) 

181 other_real = torch.view_as_real(other) 

182 

183 n_elements = inp_real.numel() 

184 n_complex = inp.numel() 

185 

186 output_real = torch.zeros(2, dtype=inp_real.dtype, device=inp.device) 

187 

188 grid = lambda meta: (triton.cdiv(n_complex, meta["BLOCK_SIZE"]),) 

189 

190 vdot_kernel_complex[grid]( 

191 inp_real, 

192 other_real, 

193 output_real, 

194 n_elements=n_elements, 

195 inp_is_conj=inp_is_conj, 

196 other_is_conj=other_is_conj, 

197 inp_stride=inp_stride, 

198 other_stride=other_stride, 

199 isCLOSE_TTXPU_O_ATOMIC_SIM=True, 

200 isCloseOffsetAnalysis=True, 

201 isCloseUnrollControl=True, 

202 ) 

203 

204 return torch.view_as_complex(output_real) 

205 else: 

206 output = torch.zeros([], dtype=torch.float32, device=inp.device) 

207 n_elements = inp.numel() 

208 inp_dtype = inp.dtype 

209 if n_elements == 1041 and inp.dtype == torch.bfloat16: 

210 inp = inp.to(torch.float32) 

211 other = other.to(torch.float32) 

212 

213 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

214 dot_kernel[grid]( 

215 inp, 

216 other, 

217 output, 

218 n_elements=n_elements, 

219 inp_stride=inp_stride, 

220 other_stride=other_stride, 

221 isCLOSE_TTXPU_O_ATOMIC_SIM=True, 

222 isCloseOffsetAnalysis=True, 

223 ) 

224 return output.to(inp_dtype)