Coverage for src/flag_gems/runtime/backend/_ascend/fused/rotary_embedding.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@triton.jit
def rotary_embedding_rw_kernel(
    state_out,
    state,
    cos,
    sin,
    stride_state_n,
    stride_state_h,
    stride_state_d,
    stride_cos_n,
    stride_cos_d,
    num_tokens,
    num_heads,
    token_range,
    head_range,
    dim_range_x,
    dim_range_y,
    rotary_interleaved: tl.constexpr,
):
    """Rotate one (token, head) tile of `state` and write it to `state_out`.

    Loads the paired x/y elements of the head dimension selected by
    `dim_range_x` / `dim_range_y`, applies the 2-D rotation given by the
    per-token cos/sin angles, and stores the rotated pair back at the same
    offsets in `state_out`.

    Device-side arguments:
        state_out / state: pointers to [num_tokens, num_heads, head_dim] data
            addressed via the `stride_state_*` element strides.
        cos / sin: pointers to per-token angle tables addressed via
            `stride_cos_n` (per token) and `stride_cos_d` (per element).
        num_tokens / num_heads: logical bounds; out-of-range lanes are masked.
        token_range / head_range: index ranges of the tile being processed.
        dim_range_x / dim_range_y: indices of the paired halves of head_dim.
            NOTE(review): the interleaved caller passes scalar ints here while
            the non-interleaved caller passes tl.arange tensors — confirm the
            `[None, None, :]` indexing below is valid for the scalar path.
        rotary_interleaved: compile-time flag; when set, sin is read at the
            `dim_range_y` offsets instead of sharing cos's offsets.
    """
    # Element offsets of the "x" half: one entry per (token, head, dim).
    state_x_offset = (
        token_range[:, None, None] * stride_state_n
        + head_range[None, :, None] * stride_state_h
        + dim_range_x[None, None, :] * stride_state_d
    )
    # Element offsets of the paired "y" half.
    state_y_offset = (
        token_range[:, None, None] * stride_state_n
        + head_range[None, :, None] * stride_state_h
        + dim_range_y[None, None, :] * stride_state_d
    )

    # cos is always read at the x positions; the table has no head axis.
    cos_sim_offset = (
        token_range[:, None, None] * stride_cos_n
        + dim_range_x[None, None, :] * stride_cos_d
    )
    if rotary_interleaved:
        # Interleaved layout: sin lives at the odd (y) positions.
        sin_sim_offset = (
            token_range[:, None, None] * stride_cos_n
            + dim_range_y[None, None, :] * stride_cos_d
        )
    else:
        # Half-split layout: cos and sin share the same offsets.
        sin_sim_offset = cos_sim_offset

    # Out-of-range tokens/heads load 0 and are never stored back.
    state_x = tl.load(
        state + state_x_offset,
        mask=(token_range[:, None, None] < num_tokens)
        & (head_range[None, :, None] < num_heads),
        other=0.0,
    )
    state_y = tl.load(
        state + state_y_offset,
        mask=(token_range[:, None, None] < num_tokens)
        & (head_range[None, :, None] < num_heads),
        other=0.0,
    )

    # Angles are promoted to fp32 so the rotation math runs in fp32.
    cos_loaded = tl.load(
        cos + cos_sim_offset,
        mask=token_range[:, None, None] < num_tokens,
        other=0.0,
    ).to(tl.float32)
    sin_loaded = tl.load(
        sin + sin_sim_offset,
        mask=token_range[:, None, None] < num_tokens,
        other=0.0,
    ).to(tl.float32)

    # Standard 2-D rotation: (x, y) -> (x*cos - y*sin, x*sin + y*cos).
    out_x = state_x * cos_loaded - state_y * sin_loaded
    out_y = state_x * sin_loaded + state_y * cos_loaded

    tl.store(
        state_out + state_x_offset,
        out_x,
        mask=(token_range[:, None, None] < num_tokens)
        & (head_range[None, :, None] < num_heads),
    )
    tl.store(
        state_out + state_y_offset,
        out_y,
        mask=(token_range[:, None, None] < num_tokens)
        & (head_range[None, :, None] < num_heads),
    )

96 

97 

@libentry()
@triton.jit
def rotary_embedding_siso_kernel(
    state_out,  # [num_tokens, head_num, head_dim]
    state,  # [num_tokens, head_num, head_dim]
    cos,  # [num_tokens, 1, head_dim // 2]
    sin,  # [num_tokens, 1, head_dim // 2]
    stride_state_n,
    stride_state_h,
    stride_state_d,
    stride_cos_n,
    stride_cos_d,
    num_tokens,
    num_heads,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_D: tl.constexpr,
    rotary_interleaved: tl.constexpr,
):
    """Single-input/single-output rotary embedding entry kernel.

    Each program instance handles a BLOCK_N x BLOCK_H tile of (token, head)
    pairs and delegates the actual rotation to `rotary_embedding_rw_kernel`.
    `BLOCK_D` is the full head dimension; the rotation pairs up its two
    halves (split layout) or adjacent even/odd elements (interleaved layout).

    Grid: (cdiv(num_tokens, BLOCK_N), cdiv(num_heads, BLOCK_H)).
    NOTE(review): on the interleaved path cos/sin are expected to already be
    expanded to full head_dim (see the host-side repeat_interleave) — the
    even index reads cos, the odd index reads sin.
    """
    # Map this program instance to its tile of tokens and heads.
    token_index = tl.program_id(0)
    token_range = token_index * BLOCK_N + tl.arange(0, BLOCK_N)
    head_index = tl.program_id(1)
    head_range = head_index * BLOCK_H + tl.arange(0, BLOCK_H)

    if rotary_interleaved:
        # Interleaved layout: rotate adjacent (2d, 2d+1) pairs, one pair
        # per loop iteration.
        for d in range(0, BLOCK_D // 2):
            dim_range_x = d * 2
            dim_range_y = d * 2 + 1

            rotary_embedding_rw_kernel(
                state_out,
                state,
                cos,
                sin,
                stride_state_n,
                stride_state_h,
                stride_state_d,
                stride_cos_n,
                stride_cos_d,
                num_tokens,
                num_heads,
                token_range,
                head_range,
                dim_range_x,
                dim_range_y,
                rotary_interleaved,
            )
    else:
        # Half-split layout: pair element d with element d + BLOCK_D // 2
        # and rotate all pairs in one call.
        dim_range_x = tl.arange(0, BLOCK_D // 2)
        dim_range_y = tl.arange(BLOCK_D // 2, BLOCK_D)
        rotary_embedding_rw_kernel(
            state_out,
            state,
            cos,
            sin,
            stride_state_n,
            stride_state_h,
            stride_state_d,
            stride_cos_n,
            stride_cos_d,
            num_tokens,
            num_heads,
            token_range,
            head_range,
            dim_range_x,
            dim_range_y,
            rotary_interleaved,
        )

166 

167 

def apply_rotary_pos_emb(
    q,
    k,
    cos,
    sin,
    position_ids: Optional[torch.IntTensor] = None,
    rotary_interleaved: bool = False,
):
    """
    Apply rotary position embedding to q and k

    Args:
        q: (*, q_heads, head_dim)
        k: (*, k_heads, head_dim)
        cos: (max_seq_len, head_dim // 2)
        sin: (max_seq_len, head_dim // 2)
        position_ids: (*, ), optional, position ids for each token
        rotary_interleaved: whether the head_dim is rotated in an interleaved way

    Returns:
        q_embed: (*, q_heads, head_dim)
        k_embed: (*, k_heads, head_dim)
    """
    logger.debug("GEMS_ASCEND ROTARY POS EMBEDDING")
    assert (
        k.shape[-1] == q.shape[-1]
    ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}"
    assert (
        cos.shape[-1] == sin.shape[-1]
    ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}"
    assert (
        cos.shape[-1] * 2 == q.shape[-1]
    ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}"
    assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension"
    assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension"

    # Remember the caller's shapes so the outputs can be restored at the end.
    q_shape = q.shape
    k_shape = k.shape

    assert (
        q.shape[:-2] == k.shape[:-2]
    ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}"
    if position_ids is None:
        assert (
            len(q.shape) == 4
        ), f"q must have 4 dimensions if position_ids is not provided, got {q.shape}"
    else:
        assert (
            position_ids.shape == q.shape[:-2]
        ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}"

        position_ids = position_ids.view(-1)

    # Flatten all leading dims into a single token axis: (num_tokens, heads, dim).
    q = q.view(-1, q.shape[-2], q.shape[-1])
    k = k.view(-1, k.shape[-2], k.shape[-1])

    q_embed = torch.empty_like(q)
    k_embed = torch.empty_like(k)

    def torch_rotary_embedding(state_out, state, cos, sin):
        # Prepare the cos/sin tables for this state and launch the Triton
        # kernel that writes the rotated values into `state_out`.
        num_tokens = state.shape[0]
        num_heads = state.shape[1]
        head_dim = state.shape[-1]

        BLOCK_N = 8
        BLOCK_H = 4
        grid = (
            triton.cdiv(num_tokens, BLOCK_N),
            triton.cdiv(num_heads, BLOCK_H),
        )
        with torch_device_fn.device(state_out.device):
            with flag_gems.use_gems():
                if position_ids is None:
                    # Positions are implicit 0..seq_len-1; take a seq_len
                    # prefix of the tables and add a singleton head axis.
                    cos = cos[: q_shape[-3], None, :]
                    sin = sin[: q_shape[-3], None, :]
                else:
                    # Gather one row per (flattened) token.
                    cos = cos[position_ids, None, :]
                    sin = sin[position_ids, None, :]

                if rotary_interleaved:
                    # Expand half-dim tables to full head_dim so even/odd
                    # indices in the kernel both find their angle.
                    cos = torch.repeat_interleave(cos, 2, dim=-1)
                    sin = torch.repeat_interleave(sin, 2, dim=-1)
                # Replicate the tables across the batch in one allocation.
                # (The previous implementation concatenated in a Python loop,
                # copying the growing tensor q_shape[0]-1 times.)
                if q_shape[0] > 1:
                    cos = cos.repeat(q_shape[0], 1, 1)
                    sin = sin.repeat(q_shape[0], 1, 1)
                rotary_embedding_siso_kernel[grid](
                    state_out,
                    state,
                    cos,
                    sin,
                    state.stride(0),
                    state.stride(1),
                    state.stride(2),
                    cos.stride(0),
                    cos.stride(2),
                    num_tokens,
                    num_heads,
                    BLOCK_N=BLOCK_N,
                    BLOCK_H=BLOCK_H,
                    BLOCK_D=head_dim,
                    rotary_interleaved=rotary_interleaved,
                )

    torch_rotary_embedding(q_embed, q, cos, sin)
    torch_rotary_embedding(k_embed, k, cos, sin)

    # Restore the caller's original leading dimensions.
    q_embed = q_embed.view(q_shape)
    k_embed = k_embed.view(k_shape)
    return q_embed, k_embed