Coverage for src/flag_gems/runtime/backend/_ascend/fused/rotary_embedding.py: 0%
86 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
2from typing import Optional
4import torch
5import triton
6import triton.language as tl
8import flag_gems
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
12logger = logging.getLogger(__name__)
@triton.jit
def rotary_embedding_rw_kernel(
    state_out,  # output pointer, same layout/strides as `state`
    state,  # input pointer: [num_tokens, num_heads, head_dim]
    cos,  # cos table pointer, indexed by (token, dim)
    sin,  # sin table pointer, indexed by (token, dim)
    stride_state_n,  # token stride of state/state_out
    stride_state_h,  # head stride of state/state_out
    stride_state_d,  # dim stride of state/state_out
    stride_cos_n,  # token stride of cos/sin
    stride_cos_d,  # dim stride of cos/sin
    num_tokens,
    num_heads,
    token_range,  # token indices handled by this program
    head_range,  # head indices handled by this program
    dim_range_x,  # dim index/indices of the first element of each rotated pair
    dim_range_y,  # dim index/indices of the second element of each rotated pair
    rotary_interleaved: tl.constexpr,
):
    # Core read-rotate-write helper: loads the (x, y) element pairs selected by
    # dim_range_x / dim_range_y, applies the 2D rotation
    #   out_x = x * cos - y * sin
    #   out_y = x * sin + y * cos
    # and stores the results back into state_out at the same offsets.

    # Element offsets into state/state_out for the first (x) and second (y)
    # member of each rotation pair, broadcast to a (token, head, dim) grid.
    state_x_offset = (
        token_range[:, None, None] * stride_state_n
        + head_range[None, :, None] * stride_state_h
        + dim_range_x[None, None, :] * stride_state_d
    )
    state_y_offset = (
        token_range[:, None, None] * stride_state_n
        + head_range[None, :, None] * stride_state_h
        + dim_range_y[None, None, :] * stride_state_d
    )

    # cos/sin are indexed per (token, dim) only; all heads share the same row.
    cos_sim_offset = (
        token_range[:, None, None] * stride_cos_n
        + dim_range_x[None, None, :] * stride_cos_d
    )
    if rotary_interleaved:
        # Interleaved layout: the caller expands cos/sin with
        # repeat_interleave(2), so x and y read adjacent (equal-valued) slots.
        sin_sim_offset = (
            token_range[:, None, None] * stride_cos_n
            + dim_range_y[None, None, :] * stride_cos_d
        )
    else:
        # Half-split layout: x and y share the same cos/sin columns.
        sin_sim_offset = cos_sim_offset

    # Masked loads: out-of-range tokens/heads read 0.0 and are never stored.
    state_x = tl.load(
        state + state_x_offset,
        mask=(token_range[:, None, None] < num_tokens)
        & (head_range[None, :, None] < num_heads),
        other=0.0,
    )
    state_y = tl.load(
        state + state_y_offset,
        mask=(token_range[:, None, None] < num_tokens)
        & (head_range[None, :, None] < num_heads),
        other=0.0,
    )

    # Upcast cos/sin to float32 so the rotation accumulates in full precision.
    cos_loaded = tl.load(
        cos + cos_sim_offset,
        mask=token_range[:, None, None] < num_tokens,
        other=0.0,
    ).to(tl.float32)
    sin_loaded = tl.load(
        sin + sin_sim_offset,
        mask=token_range[:, None, None] < num_tokens,
        other=0.0,
    ).to(tl.float32)

    # Standard rotary 2x2 rotation applied element-wise to each (x, y) pair.
    out_x = state_x * cos_loaded - state_y * sin_loaded
    out_y = state_x * sin_loaded + state_y * cos_loaded

    tl.store(
        state_out + state_x_offset,
        out_x,
        mask=(token_range[:, None, None] < num_tokens)
        & (head_range[None, :, None] < num_heads),
    )
    tl.store(
        state_out + state_y_offset,
        out_y,
        mask=(token_range[:, None, None] < num_tokens)
        & (head_range[None, :, None] < num_heads),
    )
@libentry()
@triton.jit
def rotary_embedding_siso_kernel(
    state_out,  # [num_tokens, head_num, head_dim]
    state,  # [num_tokens, head_num, head_dim]
    cos,  # [num_tokens, 1, head_dim // 2]
    sin,  # [num_tokens, 1, head_dim // 2]
    stride_state_n,
    stride_state_h,
    stride_state_d,
    stride_cos_n,
    stride_cos_d,
    num_tokens,
    num_heads,
    BLOCK_N: tl.constexpr,  # tokens covered per program
    BLOCK_H: tl.constexpr,  # heads covered per program
    BLOCK_D: tl.constexpr,  # full head_dim
    rotary_interleaved: tl.constexpr,
):
    # Entry kernel: each program handles a (BLOCK_N tokens) x (BLOCK_H heads)
    # tile and delegates the actual rotation to rotary_embedding_rw_kernel,
    # choosing how head_dim elements are paired.
    token_index = tl.program_id(0)
    token_range = token_index * BLOCK_N + tl.arange(0, BLOCK_N)
    head_index = tl.program_id(1)
    head_range = head_index * BLOCK_H + tl.arange(0, BLOCK_H)

    if rotary_interleaved:
        # Interleaved pairing: element 2d rotates with element 2d+1,
        # one scalar (x, y) pair per loop iteration.
        for d in range(0, BLOCK_D // 2):
            dim_range_x = d * 2
            dim_range_y = d * 2 + 1

            rotary_embedding_rw_kernel(
                state_out,
                state,
                cos,
                sin,
                stride_state_n,
                stride_state_h,
                stride_state_d,
                stride_cos_n,
                stride_cos_d,
                num_tokens,
                num_heads,
                token_range,
                head_range,
                dim_range_x,
                dim_range_y,
                rotary_interleaved,
            )
    else:
        # Half-split pairing: element d rotates with element d + BLOCK_D // 2,
        # processing the whole first half in one vectorized call.
        dim_range_x = tl.arange(0, BLOCK_D // 2)
        dim_range_y = tl.arange(BLOCK_D // 2, BLOCK_D)
        rotary_embedding_rw_kernel(
            state_out,
            state,
            cos,
            sin,
            stride_state_n,
            stride_state_h,
            stride_state_d,
            stride_cos_n,
            stride_cos_d,
            num_tokens,
            num_heads,
            token_range,
            head_range,
            dim_range_x,
            dim_range_y,
            rotary_interleaved,
        )
def apply_rotary_pos_emb(
    q,
    k,
    cos,
    sin,
    position_ids: Optional[torch.IntTensor] = None,
    rotary_interleaved: bool = False,
):
    """
    Apply rotary position embedding to q and k.

    Args:
        q: (*, q_heads, head_dim)
        k: (*, k_heads, head_dim)
        cos: (max_seq_len, head_dim // 2)
        sin: (max_seq_len, head_dim // 2)
        position_ids: (*, ), optional, position ids for each token
        rotary_interleaved: whether the head_dim is rotated in an interleaved way

    Returns:
        q_embed: (*, q_heads, head_dim)
        k_embed: (*, k_heads, head_dim)
    """
    logger.debug("GEMS_ASCEND ROTARY POS EMBEDDING")
    assert (
        k.shape[-1] == q.shape[-1]
    ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}"
    assert (
        cos.shape[-1] == sin.shape[-1]
    ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}"
    assert (
        cos.shape[-1] * 2 == q.shape[-1]
    ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}"
    assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension"
    assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension"

    q_shape = q.shape
    k_shape = k.shape

    assert (
        q.shape[:-2] == k.shape[:-2]
    ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}"
    if position_ids is None:
        assert (
            len(q.shape) == 4
        ), f"q must have 4 dimensions if position_ids is not provided, got {q.shape}"
    else:
        assert (
            position_ids.shape == q.shape[:-2]
        ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}"
        position_ids = position_ids.view(-1)

    # Flatten all leading dims into a single token axis for the kernel.
    q = q.view(-1, q.shape[-2], q.shape[-1])
    k = k.view(-1, k.shape[-2], k.shape[-1])

    q_embed = torch.empty_like(q)
    k_embed = torch.empty_like(k)

    def torch_rotary_embedding(state_out, state, cos, sin):
        # Prepare the per-token cos/sin tables and launch the Triton kernel
        # over a (token, head) grid; writes into state_out in place.
        num_tokens = state.shape[0]
        num_heads = state.shape[1]
        head_dim = state.shape[-1]

        BLOCK_N = 8
        BLOCK_H = 4
        grid = (
            triton.cdiv(num_tokens, BLOCK_N),
            triton.cdiv(num_heads, BLOCK_H),
        )
        with torch_device_fn.device(state_out.device):
            with flag_gems.use_gems():
                # Select per-token rows and insert a broadcast head axis:
                # (tokens, 1, dim).
                if position_ids is None:
                    cos = cos[: q_shape[-3], None, :]
                    sin = sin[: q_shape[-3], None, :]
                else:
                    cos = cos[position_ids, None, :]
                    sin = sin[position_ids, None, :]

                if rotary_interleaved:
                    # Expand to full head_dim so adjacent (x, y) pairs read
                    # equal cos/sin values in the kernel.
                    cos = torch.repeat_interleave(cos, 2, dim=-1)
                    sin = torch.repeat_interleave(sin, 2, dim=-1)

                # Tile along the token axis so row b*seq+s of the table lines
                # up with flattened token (b, s). A single torch.cat of the
                # repeated list replaces the previous incremental cat loop,
                # which re-copied the growing tensor every iteration
                # (O(batch^2) data movement) while producing the same tensor.
                # NOTE(review): when position_ids is provided the table already
                # has one row per flattened token and the extra tiles are never
                # read (the kernel masks token_range < num_tokens); tiling is
                # kept unconditional for exact parity with the old behavior.
                if q_shape[0] > 1:
                    cos = torch.cat([cos] * q_shape[0], dim=0)
                    sin = torch.cat([sin] * q_shape[0], dim=0)

                rotary_embedding_siso_kernel[grid](
                    state_out,
                    state,
                    cos,
                    sin,
                    state.stride(0),
                    state.stride(1),
                    state.stride(2),
                    cos.stride(0),
                    cos.stride(2),
                    num_tokens,
                    num_heads,
                    BLOCK_N=BLOCK_N,
                    BLOCK_H=BLOCK_H,
                    BLOCK_D=head_dim,
                    rotary_interleaved=rotary_interleaved,
                )

    torch_rotary_embedding(q_embed, q, cos, sin)
    torch_rotary_embedding(k_embed, k, cos, sin)

    q_embed = q_embed.view(q_shape)
    k_embed = k_embed.view(k_shape)
    return q_embed, k_embed