Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/rotary_embedding.py: 0%
128 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2from typing import Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Child logger under the shared "flag_gems" root logger.
# NOTE(review): str.lstrip(".") only removes *leading* dots; for a normal
# absolute __name__ this is a no-op — presumably it guards against
# relative-import-style names. Confirm intent against the project convention.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Out-of-place rotary position embedding (RoPE) kernel.
# Grid: one program per token (axis 0). Each program applies the rotation for
# its token's position to every q head and every k head, writing the results
# into the separate output buffers oq/ok (inputs q/k are only read).
@libentry()
@triton.jit
def apply_rotary_pos_emb_kernel(
    oq_ptr,  # output for q, indexed with its own strides below
    ok_ptr,  # output for k, indexed with its own strides below
    q_ptr,  # (n_tokens, q_heads, head_dim)
    k_ptr,  # (n_tokens, k_heads, head_dim)
    cos_ptr,  # (max_seq_len, dim // 2)
    sin_ptr,  # (max_seq_len, dim // 2)
    pos_ptr,  # (n_tokens, ); may be None -> position falls back to s_id % seq_len
    q_stride_s,
    q_stride_h,
    q_stride_d,
    k_stride_s,
    k_stride_h,
    k_stride_d,
    oq_stride_s,
    oq_stride_h,
    oq_stride_d,
    ok_stride_s,
    ok_stride_h,
    ok_stride_d,
    p_stride_s,
    cos_stride_s,
    sin_stride_s,
    seq_len,  # only read when pos_ptr is None
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    PADDED_HEAD_DIM: tl.constexpr,  # power-of-two block size >= HEAD_DIM
    ROTARY_INTERLEAVED: tl.constexpr,
    MAX_POSITION_EMBEDDINGS: tl.constexpr,  # number of rows in cos/sin tables
):
    # Token index handled by this program instance.
    s_id = tle.program_id(0)
    # Resolve the position id for this token, then point cos/sin at its row.
    if pos_ptr is None:
        pos_id = s_id % seq_len
    else:
        pos_ptr += s_id * p_stride_s
        pos_id = tl.load(pos_ptr)
    cos_ptr += pos_id * cos_stride_s
    sin_ptr += pos_id * sin_stride_s
    # note: set TRITON_DEBUG=1 to enable this check
    tl.device_assert(pos_id < MAX_POSITION_EMBEDDINGS, "position id out of bound")
    # ordered_block enumerates head-dim lanes; mask hides the pow2 padding.
    ordered_block = tl.arange(0, PADDED_HEAD_DIM)
    mask = ordered_block < HEAD_DIM
    if ROTARY_INTERLEAVED:
        # Interleaved layout: dims (0,1), (2,3), ... form the rotation pairs.
        # NOTE: odd_mask is True on *even* lanes (name is historical).
        odd_mask = ordered_block % 2 == 0
        # Each lane's rotation partner is its pair neighbour.
        rotated_block = tl.where(odd_mask, ordered_block + 1, ordered_block - 1)
        # Pair j uses row entry j of the half-width cos/sin tables.
        sin_cos_block = ordered_block // 2
        cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)
        sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)
        # Negate sin on even lanes so y = x*cos + x_partner*sin realizes the
        # 2x2 rotation [cos -sin; sin cos] on each (even, odd) pair.
        sin = tl.where(odd_mask, -sin, sin)
    else:
        # Half-split layout: dim i pairs with dim i + HEAD_DIM//2 (mod HEAD_DIM).
        rotated_block = (ordered_block + HEAD_DIM // 2) % HEAD_DIM
        sin_cos_block = ordered_block % (HEAD_DIM // 2)
        cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)
        sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)
        # rotated_block < HEAD_DIM//2 holds exactly for second-half lanes, so
        # first-half lanes get -sin and second-half lanes keep +sin.
        sin = tl.where(rotated_block < HEAD_DIM // 2, sin, -sin)
    # Advance base pointers to this token's row, then rotate each q head.
    oq_ptr += s_id * oq_stride_s
    q_ptr += s_id * q_stride_s
    for off_h in range(0, NUM_Q_HEADS):
        ordered_cols = off_h * q_stride_h + (ordered_block * q_stride_d)
        rotated_cols = off_h * q_stride_h + (rotated_block * q_stride_d)
        output_offs = off_h * oq_stride_h + (ordered_block * oq_stride_d)
        q = tl.load(q_ptr + ordered_cols, mask=mask, other=0.0)
        rotated_q = tl.load(q_ptr + rotated_cols, mask=mask, other=0.0)
        # y_i = x_i*cos_i + x_partner(i)*sin_i with the signs baked into sin.
        y = q * cos + rotated_q * sin
        tl.store(oq_ptr + output_offs, y, mask=mask)
    # Same rotation for each k head.
    ok_ptr += s_id * ok_stride_s
    k_ptr += s_id * k_stride_s
    for off_h in range(0, NUM_K_HEADS):
        ordered_cols = off_h * k_stride_h + (ordered_block * k_stride_d)
        rotated_cols = off_h * k_stride_h + (rotated_block * k_stride_d)
        output_offs = off_h * ok_stride_h + (ordered_block * ok_stride_d)
        k = tl.load(k_ptr + ordered_cols, mask=mask, other=0.0)
        rotated_k = tl.load(k_ptr + rotated_cols, mask=mask, other=0.0)
        y = k * cos + rotated_k * sin
        tl.store(ok_ptr + output_offs, y, mask=mask)
# In-place rotary position embedding (RoPE) kernel.
# Identical math to apply_rotary_pos_emb_kernel, but results overwrite q/k.
# Per head, both loads (the lane and its rotation partner) complete before the
# store to the same head, so the in-place update reads only pre-rotation data.
@libentry()
@triton.jit
def apply_rotary_pos_emb_inplace_kernel(
    q_ptr,  # (n_tokens, q_heads, head_dim), updated in place
    k_ptr,  # (n_tokens, k_heads, head_dim), updated in place
    cos_ptr,  # (max_seq_len, dim // 2)
    sin_ptr,  # (max_seq_len, dim // 2)
    pos_ptr,  # (n_tokens, ); may be None -> position falls back to s_id % seq_len
    q_stride_s,
    q_stride_h,
    q_stride_d,
    k_stride_s,
    k_stride_h,
    k_stride_d,
    p_stride_s,
    cos_stride_s,
    sin_stride_s,
    seq_len,  # only read when pos_ptr is None
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    PADDED_HEAD_DIM: tl.constexpr,  # power-of-two block size >= HEAD_DIM
    ROTARY_INTERLEAVED: tl.constexpr,
    MAX_POSITION_EMBEDDINGS: tl.constexpr,  # number of rows in cos/sin tables
):
    # Token index handled by this program instance.
    s_id = tle.program_id(0)
    # Resolve the position id for this token, then point cos/sin at its row.
    if pos_ptr is None:
        pos_id = s_id % seq_len
    else:
        pos_ptr += s_id * p_stride_s
        pos_id = tl.load(pos_ptr)
    cos_ptr += pos_id * cos_stride_s
    sin_ptr += pos_id * sin_stride_s
    # note: set TRITON_DEBUG=1 to enable this check
    tl.device_assert(pos_id < MAX_POSITION_EMBEDDINGS, "position id out of bound")
    # ordered_block enumerates head-dim lanes; mask hides the pow2 padding.
    ordered_block = tl.arange(0, PADDED_HEAD_DIM)
    mask = ordered_block < HEAD_DIM
    if ROTARY_INTERLEAVED:
        # Interleaved layout: dims (0,1), (2,3), ... form the rotation pairs.
        # NOTE: odd_mask is True on *even* lanes (name is historical).
        odd_mask = ordered_block % 2 == 0
        rotated_block = tl.where(odd_mask, ordered_block + 1, ordered_block - 1)
        sin_cos_block = ordered_block // 2
        cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)
        sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)
        # Negate sin on even lanes so y = x*cos + x_partner*sin realizes the
        # 2x2 rotation [cos -sin; sin cos] on each (even, odd) pair.
        sin = tl.where(odd_mask, -sin, sin)
    else:
        # Half-split layout: dim i pairs with dim i + HEAD_DIM//2 (mod HEAD_DIM).
        rotated_block = (ordered_block + HEAD_DIM // 2) % HEAD_DIM
        sin_cos_block = ordered_block % (HEAD_DIM // 2)
        cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)
        sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)
        # rotated_block < HEAD_DIM//2 holds exactly for second-half lanes, so
        # first-half lanes get -sin and second-half lanes keep +sin.
        sin = tl.where(rotated_block < HEAD_DIM // 2, sin, -sin)
    # Rotate each q head of this token, overwriting the input.
    q_ptr += s_id * q_stride_s
    for off_h in range(0, NUM_Q_HEADS):
        ordered_cols = off_h * q_stride_h + (ordered_block * q_stride_d)
        rotated_cols = off_h * q_stride_h + (rotated_block * q_stride_d)
        q = tl.load(q_ptr + ordered_cols, mask=mask, other=0.0)
        rotated_q = tl.load(q_ptr + rotated_cols, mask=mask, other=0.0)
        y = q * cos + rotated_q * sin
        tl.store(q_ptr + ordered_cols, y, mask=mask)  # In-place update
    # Same for each k head.
    k_ptr += s_id * k_stride_s
    for off_h in range(0, NUM_K_HEADS):
        ordered_cols = off_h * k_stride_h + (ordered_block * k_stride_d)
        rotated_cols = off_h * k_stride_h + (rotated_block * k_stride_d)
        k = tl.load(k_ptr + ordered_cols, mask=mask, other=0.0)
        rotated_k = tl.load(k_ptr + rotated_cols, mask=mask, other=0.0)
        y = k * cos + rotated_k * sin
        tl.store(k_ptr + ordered_cols, y, mask=mask)  # In-place update
def apply_rotary_pos_emb(
    q,
    k,
    cos,
    sin,
    position_ids: Optional[torch.IntTensor] = None,
    rotary_interleaved: bool = False,
    inplace: bool = False,
):
    """Apply rotary position embedding (RoPE) to query and key tensors.

    Args:
        q: (*, q_heads, head_dim) query tensor.
        k: (*, k_heads, head_dim) key tensor; leading dims must match q's.
        cos: (max_seq_len, head_dim // 2) cosine table, last dim contiguous.
        sin: (max_seq_len, head_dim // 2) sine table, last dim contiguous.
        position_ids: optional (*, ) per-token positions. When omitted, q/k
            must be 4-D and positions default to the sequence index.
        rotary_interleaved: if True, adjacent dims (0,1), (2,3), ... form the
            rotation pairs; otherwise dim i pairs with i + head_dim // 2.
        inplace: if True, q and k are overwritten and returned as views.

    Returns:
        (q_embed, k_embed) with the same shapes as q and k.
    """
    logger.debug("GEMS ROTARY_POS_EMBEDDING")
    # Validate shapes/layout up front; messages must stay informative.
    assert (
        k.shape[-1] == q.shape[-1]
    ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}"
    assert (
        cos.shape[-1] == sin.shape[-1]
    ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}"
    assert (
        cos.shape[-1] * 2 == q.shape[-1]
    ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}"
    assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension"
    assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension"

    # Remember the caller-facing shapes so we can restore them on return.
    orig_q_shape, orig_k_shape = q.shape, k.shape
    assert (
        q.shape[:-2] == k.shape[:-2]
    ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}"

    if position_ids is None:
        assert (
            len(q.shape) == 4
        ), f"q must have 4 dimensions if position_ids is not provided, got {q.shape}"
        # Kernel derives each token's position as token_index % seq_len.
        seq_len = q.shape[-3]
    else:
        assert (
            position_ids.shape == q.shape[:-2]
        ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}"
        position_ids = position_ids.view(-1)
        seq_len = None  # unused by the kernel when explicit positions exist

    # Flatten all leading dims into one token axis: (n_tokens, heads, dim).
    q = q.view(-1, q.shape[-2], q.shape[-1])
    k = k.view(-1, k.shape[-2], k.shape[-1])
    num_tokens = q.shape[0]
    head_size = q.shape[-1]

    # Triton block sizes must be powers of two (and at least 16 here).
    padded_head_size = max(triton.next_power_of_2(head_size), 16)
    grid = (num_tokens,)
    pos_stride = position_ids.stride(0) if position_ids is not None else 0

    if inplace:
        # Overwrite q/k directly; no output allocation needed.
        with torch_device_fn.device(q.device):
            apply_rotary_pos_emb_inplace_kernel[grid](
                q,
                k,
                cos,
                sin,
                position_ids,
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                pos_stride,
                cos.stride(0),
                sin.stride(0),
                seq_len,
                q.shape[-2],
                k.shape[-2],
                head_size,
                padded_head_size,
                rotary_interleaved,
                MAX_POSITION_EMBEDDINGS=cos.shape[0],
                isCloseUnrollControl=True,
            )
        return q.view(orig_q_shape), k.view(orig_k_shape)

    # Out-of-place: allocate fresh outputs matching the flattened inputs.
    q_out = torch.empty_like(q)
    k_out = torch.empty_like(k)
    with torch_device_fn.device(q_out.device):
        apply_rotary_pos_emb_kernel[grid](
            q_out,
            k_out,
            q,
            k,
            cos,
            sin,
            position_ids,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            q_out.stride(0),
            q_out.stride(1),
            q_out.stride(2),
            k_out.stride(0),
            k_out.stride(1),
            k_out.stride(2),
            pos_stride,
            cos.stride(0),
            sin.stride(0),
            seq_len,
            q.shape[-2],
            k.shape[-2],
            head_size,
            padded_head_size,
            rotary_interleaved,
            MAX_POSITION_EMBEDDINGS=cos.shape[0],
            isCloseUnrollControl=True,
        )
    return q_out.view(orig_q_shape), k_out.view(orig_k_shape)