Coverage for src/flag_gems/fused/rotary_embedding.py: 34%

128 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit
def apply_rotary_pos_emb_kernel(
    oq_ptr,
    ok_ptr,
    q_ptr,  # (n_tokens, q_heads, head_dim)
    k_ptr,  # (n_tokens, k_heads, head_dim)
    cos_ptr,  # (max_seq_len, dim // 2)
    sin_ptr,  # (max_seq_len, dim // 2)
    pos_ptr,  # (n_tokens, )
    q_stride_s,
    q_stride_h,
    q_stride_d,
    k_stride_s,
    k_stride_h,
    k_stride_d,
    oq_stride_s,
    oq_stride_h,
    oq_stride_d,
    ok_stride_s,
    ok_stride_h,
    ok_stride_d,
    p_stride_s,
    cos_stride_s,
    sin_stride_s,
    seq_len,
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    PADDED_HEAD_DIM: tl.constexpr,
    ROTARY_INTERLEAVED: tl.constexpr,
    MAX_POSITION_EMBEDDINGS: tl.constexpr,
):
    # Out-of-place rotary embedding: program `token` reads one row of q/k and
    # writes the rotated row to oq/ok, looping over every head of that row.
    token = tle.program_id(0)

    # Pick the cos/sin table row: either the explicit position id of this
    # token, or its index modulo seq_len when no position tensor is given.
    if pos_ptr is None:
        pos_id = token % seq_len
    else:
        pos_id = tl.load(pos_ptr + token * p_stride_s)
    cos_row = cos_ptr + pos_id * cos_stride_s
    sin_row = sin_ptr + pos_id * sin_stride_s

    # note: set TRITON_DEBUG=1 to enable this check
    tl.device_assert(pos_id < MAX_POSITION_EMBEDDINGS, "position id out of bound")

    # `lane` indexes the (power-of-two padded) head dimension; lanes beyond
    # HEAD_DIM are masked off in every load and store below.
    lane = tl.arange(0, PADDED_HEAD_DIM)
    in_bounds = lane < HEAD_DIM

    # For each lane determine: the partner lane it is combined with, the
    # cos/sin table column it uses, and whether its sin term is negated.
    if ROTARY_INTERLEAVED:
        # Neighbouring lanes (2i, 2i + 1) form a pair sharing table column i;
        # the even lane of the pair takes the negated sin.
        is_even = lane % 2 == 0
        partner = tl.where(is_even, lane + 1, lane - 1)
        table_col = lane // 2
        flip_sign = is_even
    else:
        # Lane i pairs with lane i + HEAD_DIM // 2 (wrapping around); lanes
        # whose partner lies in the upper half take the negated sin.
        partner = (lane + HEAD_DIM // 2) % HEAD_DIM
        table_col = lane % (HEAD_DIM // 2)
        flip_sign = partner >= HEAD_DIM // 2

    # cos/sin are promoted to float32 before the multiply-add.
    cos = tl.load(cos_row + table_col, mask=in_bounds, other=0.0).to(tl.float32)
    sin = tl.load(sin_row + table_col, mask=in_bounds, other=0.0).to(tl.float32)
    sin = tl.where(flip_sign, -sin, sin)

    # Rotate every query head of this token.
    q_row = q_ptr + token * q_stride_s
    oq_row = oq_ptr + token * oq_stride_s
    for head in range(0, NUM_Q_HEADS):
        src_head = q_row + head * q_stride_h
        dst_head = oq_row + head * oq_stride_h

        x = tl.load(src_head + lane * q_stride_d, mask=in_bounds, other=0.0)
        x_pair = tl.load(src_head + partner * q_stride_d, mask=in_bounds, other=0.0)
        rotated = x * cos + x_pair * sin
        tl.store(dst_head + lane * oq_stride_d, rotated, mask=in_bounds)

    # Rotate every key head of this token with the same cos/sin values.
    k_row = k_ptr + token * k_stride_s
    ok_row = ok_ptr + token * ok_stride_s
    for head in range(0, NUM_K_HEADS):
        src_head = k_row + head * k_stride_h
        dst_head = ok_row + head * ok_stride_h

        x = tl.load(src_head + lane * k_stride_d, mask=in_bounds, other=0.0)
        x_pair = tl.load(src_head + partner * k_stride_d, mask=in_bounds, other=0.0)
        rotated = x * cos + x_pair * sin
        tl.store(dst_head + lane * ok_stride_d, rotated, mask=in_bounds)

102 

103 

@libentry()
@triton.jit
def apply_rotary_pos_emb_inplace_kernel(
    q_ptr,  # (n_tokens, q_heads, head_dim)
    k_ptr,  # (n_tokens, k_heads, head_dim)
    cos_ptr,  # (max_seq_len, dim // 2)
    sin_ptr,  # (max_seq_len, dim // 2)
    pos_ptr,  # (n_tokens, )
    q_stride_s,
    q_stride_h,
    q_stride_d,
    k_stride_s,
    k_stride_h,
    k_stride_d,
    p_stride_s,
    cos_stride_s,
    sin_stride_s,
    seq_len,
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    PADDED_HEAD_DIM: tl.constexpr,
    ROTARY_INTERLEAVED: tl.constexpr,
    MAX_POSITION_EMBEDDINGS: tl.constexpr,
):
    # In-place rotary embedding: program `token` rotates one row of q and k
    # and writes the result back over the same memory it was read from.
    token = tle.program_id(0)

    # Pick the cos/sin table row: either the explicit position id of this
    # token, or its index modulo seq_len when no position tensor is given.
    if pos_ptr is None:
        pos_id = token % seq_len
    else:
        pos_id = tl.load(pos_ptr + token * p_stride_s)
    cos_row = cos_ptr + pos_id * cos_stride_s
    sin_row = sin_ptr + pos_id * sin_stride_s

    # note: set TRITON_DEBUG=1 to enable this check
    tl.device_assert(pos_id < MAX_POSITION_EMBEDDINGS, "position id out of bound")

    # `lane` indexes the (power-of-two padded) head dimension; lanes beyond
    # HEAD_DIM are masked off in every load and store below.
    lane = tl.arange(0, PADDED_HEAD_DIM)
    in_bounds = lane < HEAD_DIM

    # For each lane determine: the partner lane it is combined with, the
    # cos/sin table column it uses, and whether its sin term is negated.
    if ROTARY_INTERLEAVED:
        # Neighbouring lanes (2i, 2i + 1) form a pair sharing table column i;
        # the even lane of the pair takes the negated sin.
        is_even = lane % 2 == 0
        partner = tl.where(is_even, lane + 1, lane - 1)
        table_col = lane // 2
        flip_sign = is_even
    else:
        # Lane i pairs with lane i + HEAD_DIM // 2 (wrapping around); lanes
        # whose partner lies in the upper half take the negated sin.
        partner = (lane + HEAD_DIM // 2) % HEAD_DIM
        table_col = lane % (HEAD_DIM // 2)
        flip_sign = partner >= HEAD_DIM // 2

    # cos/sin are promoted to float32 before the multiply-add.
    cos = tl.load(cos_row + table_col, mask=in_bounds, other=0.0).to(tl.float32)
    sin = tl.load(sin_row + table_col, mask=in_bounds, other=0.0).to(tl.float32)
    sin = tl.where(flip_sign, -sin, sin)

    # NOTE(review): every lane reads its partner's slot and then overwrites
    # its own slot in the same buffer. This presumably relies on all loads of
    # a head being issued before any store of that head lands - confirm there
    # is no cross-warp read-after-write hazard on the targeted backends.
    q_row = q_ptr + token * q_stride_s
    for head in range(0, NUM_Q_HEADS):
        head_base = q_row + head * q_stride_h
        own_slot = head_base + lane * q_stride_d

        x = tl.load(own_slot, mask=in_bounds, other=0.0)
        x_pair = tl.load(head_base + partner * q_stride_d, mask=in_bounds, other=0.0)
        rotated = x * cos + x_pair * sin
        tl.store(own_slot, rotated, mask=in_bounds)  # In-place update

    k_row = k_ptr + token * k_stride_s
    for head in range(0, NUM_K_HEADS):
        head_base = k_row + head * k_stride_h
        own_slot = head_base + lane * k_stride_d

        x = tl.load(own_slot, mask=in_bounds, other=0.0)
        x_pair = tl.load(head_base + partner * k_stride_d, mask=in_bounds, other=0.0)
        rotated = x * cos + x_pair * sin
        tl.store(own_slot, rotated, mask=in_bounds)  # In-place update

179 

180 

def apply_rotary_pos_emb(
    q,
    k,
    cos,
    sin,
    position_ids: Optional[torch.IntTensor] = None,
    rotary_interleaved: bool = False,
    inplace: bool = False,
):
    """
    Apply rotary position embedding to q and k

    Args:
        q: (*, q_heads, head_dim)
        k: (*, k_heads, head_dim)
        cos: (max_seq_len, head_dim // 2)
        sin: (max_seq_len, head_dim // 2)
        position_ids: (*, ), optional, position ids for each token. When it is
            omitted, q must have exactly 4 dimensions and each token uses its
            index along q.shape[-3] as position, so q.shape[-3] must not
            exceed the number of rows of cos and sin.
        rotary_interleaved: whether the head_dim is rotated in an interleaved way
        inplace: if True, q and k are overwritten with the rotated values and
            views of them are returned; if False, newly allocated tensors are
            returned and q and k are left untouched.

    Returns:
        q_embed: (*, q_heads, head_dim)
        k_embed: (*, k_heads, head_dim)

    Raises:
        AssertionError: if the shapes of q, k, cos, sin and position_ids are
            inconsistent, or cos/sin are not contiguous in the last dimension.
    """
    logger.debug("GEMS ROTARY_POS_EMBEDDING")
    assert (
        k.shape[-1] == q.shape[-1]
    ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}"
    assert (
        cos.shape[-1] == sin.shape[-1]
    ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}"
    assert (
        cos.shape[-1] * 2 == q.shape[-1]
    ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}"
    assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension"
    assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension"

    # Remember the caller-visible shapes; q and k are flattened below.
    q_shape = q.shape
    k_shape = k.shape

    assert (
        q.shape[:-2] == k.shape[:-2]
    ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}"
    if position_ids is None:
        assert (
            len(q.shape) == 4
        ), f"q must have 4 dimensions if position_ids is not provided, got {q.shape}"
        seq_len = q.shape[-3]
        # The kernel indexes the cos/sin tables with `token % seq_len`; its
        # own bound check is only active in debug mode, so reject sequences
        # longer than the tables here instead of reading out of bounds.
        assert seq_len <= cos.shape[0] and seq_len <= sin.shape[0], (
            f"sequence length {seq_len} exceeds the cos/sin table length, "
            f"got {cos.shape} and {sin.shape}"
        )
    else:
        assert (
            position_ids.shape == q.shape[:-2]
        ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}"

        position_ids = position_ids.view(-1)
        seq_len = None

    # Collapse all leading dimensions into a single token dimension.
    q = q.view(-1, q.shape[-2], q.shape[-1])
    k = k.view(-1, k.shape[-2], k.shape[-1])

    n_tokens, q_heads, head_dim = q.shape
    k_heads = k.shape[-2]

    # The block size must be the next power of two, sometimes we need to pad it.
    padded_head_dim = max(triton.next_power_of_2(head_dim), 16)

    # One program per token; the stride is unused when no position ids exist.
    grid = (n_tokens,)
    pos_stride = position_ids.stride(0) if position_ids is not None else 0

    if inplace:
        with torch_device_fn.device(q.device):
            apply_rotary_pos_emb_inplace_kernel[grid](
                q,
                k,
                cos,
                sin,
                position_ids,
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                pos_stride,
                cos.stride(0),
                sin.stride(0),
                seq_len,
                q_heads,
                k_heads,
                head_dim,
                padded_head_dim,
                rotary_interleaved,
                MAX_POSITION_EMBEDDINGS=cos.shape[0],
            )
        return q.view(q_shape), k.view(k_shape)

    # If not inplace, we need to create new tensors for q_embed and k_embed
    q_embed = torch.empty_like(q)
    k_embed = torch.empty_like(k)

    with torch_device_fn.device(q_embed.device):
        apply_rotary_pos_emb_kernel[grid](
            q_embed,
            k_embed,
            q,
            k,
            cos,
            sin,
            position_ids,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            q_embed.stride(0),
            q_embed.stride(1),
            q_embed.stride(2),
            k_embed.stride(0),
            k_embed.stride(1),
            k_embed.stride(2),
            pos_stride,
            cos.stride(0),
            sin.stride(0),
            seq_len,
            q_heads,
            k_heads,
            head_dim,
            padded_head_dim,
            rotary_interleaved,
            MAX_POSITION_EMBEDDINGS=cos.shape[0],
        )
    return q_embed.view(q_shape), k_embed.view(k_shape)