Coverage for src/flag_gems/fused/deepseek_v4_attention_fused_q_kv_rmsnorm.py: 53%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1from typing import Tuple
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
# eps is listed in do_not_specialize so Triton does not specialize the
# compiled kernel on its value.
@triton.jit(do_not_specialize=["eps"])
def _fused_q_kv_rmsnorm_kernel(
    # q branch: input base pointer, output base pointer, weight pointer and
    # the row (dim-0) strides of input and output, in elements.
    q_ptr,
    q_out_ptr,
    q_weight_ptr,
    q_in_stride,
    q_out_stride,
    # kv branch: same layout as the q branch.
    kv_ptr,
    kv_out_ptr,
    kv_weight_ptr,
    kv_in_stride,
    kv_out_stride,
    # Added to the mean square before rsqrt.
    eps,
    # Row widths (number of columns) of q and kv.
    Q_SIZE: tl.constexpr,
    KV_SIZE: tl.constexpr,
    # Tile width; must be >= max(Q_SIZE, KV_SIZE) (see docstring).
    BLOCK_SIZE: tl.constexpr,
):
    """RMS-normalize a single row of either the q or the kv input.

    The host wrapper launches this on a ``(num_tokens, 2)`` grid. Program
    axis 0 selects the row (token). Program axis 1 selects the tensor:
    ``0`` means the q tensor with ``Q_SIZE`` columns, and any other value
    means the kv tensor with ``KV_SIZE`` columns.

    For the selected row ``x`` of length ``size``, the kernel stores
    ``x * rsqrt(sum(x**2) / size + eps) * w``. Accumulation is in float32,
    and the result is cast to the output element type.

    The whole row is processed as one ``tl.arange(0, BLOCK_SIZE)`` tile
    with no loop, so ``BLOCK_SIZE`` must be >= both ``Q_SIZE`` and
    ``KV_SIZE``. Otherwise, elements past ``BLOCK_SIZE`` would be neither
    read nor written, while the mean would still divide by the full
    ``size``.

    Elements within a row are addressed as ``row + offs``. This assumes
    unit stride along the last dimension, and each weight buffer must hold
    at least ``size`` contiguous elements.
    """
    # Widen to int64 so token_idx * stride below is computed in 64 bits
    # (avoids int32 overflow of the row offset on large inputs).
    token_idx = tl.program_id(0).to(tl.int64)
    # 0 -> normalize a q row; anything else -> normalize a kv row.
    task = tl.program_id(1)

    if task == 0:
        size = Q_SIZE
        row_in = q_ptr + token_idx * q_in_stride
        row_out = q_out_ptr + token_idx * q_out_stride
        weight_ptr = q_weight_ptr
    else:
        size = KV_SIZE
        row_in = kv_ptr + token_idx * kv_in_stride
        row_out = kv_out_ptr + token_idx * kv_out_stride
        weight_ptr = kv_weight_ptr

    # One tile covers the row; lanes at or beyond `size` are masked off.
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < size
    # Masked lanes load 0.0 so they contribute nothing to the sum of squares.
    x = tl.load(row_in + offs, mask=mask, other=0.0).to(tl.float32)
    # Mean square over the true row length, not over BLOCK_SIZE.
    var = tl.sum(x * x, axis=0) / size
    rrms = tl.rsqrt(var + eps)
    w = tl.load(weight_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    y = x * rrms * w
    # Cast back to the output tensor's element type before storing.
    tl.store(row_out + offs, y.to(row_out.dtype.element_ty), mask=mask)
def fused_q_kv_rmsnorm(
    qr: torch.Tensor,
    kv: torch.Tensor,
    q_weight: torch.Tensor,
    kv_weight: torch.Tensor,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply RMSNorm to the rows of ``qr`` and ``kv`` in one fused launch.

    Each row ``r`` of ``qr`` becomes ``r * rsqrt(mean(r**2) + eps) * q_weight``.
    Each row of ``kv`` is normalized the same way with ``kv_weight``. The
    statistics and scaling are computed in float32, and the result is cast back
    to the dtype of the corresponding output tensor.

    Args:
        qr: 2-D tensor ``(num_tokens, q_size)`` with unit stride along the
            last dimension (rows may be strided).
        kv: 2-D tensor ``(num_tokens, kv_size)`` with unit stride along the
            last dimension.
        q_weight: contiguous tensor with exactly ``q_size`` elements.
        kv_weight: contiguous tensor with exactly ``kv_size`` elements.
        eps: value added to the mean square before ``rsqrt``.

    Returns:
        ``(qr_out, kv_out)``: newly allocated tensors created with
        ``torch.empty_like`` from ``qr`` and ``kv``. When there are no tokens
        or both rows are zero-width, they are returned without launching
        the kernel.

    Raises:
        AssertionError: if a shape, stride or device precondition fails.
            These checks are ``assert`` statements and are skipped under
            ``python -O``.
    """
    assert qr.ndim == 2 and kv.ndim == 2, (
        f"qr and kv must be 2-D, got qr.ndim={qr.ndim}, kv.ndim={kv.ndim}"
    )
    assert qr.shape[0] == kv.shape[0], (
        f"qr and kv must have the same number of rows, "
        f"got {qr.shape[0]} and {kv.shape[0]}"
    )
    # The kernel addresses row elements as `row + offs`, i.e. unit stride.
    assert qr.stride(-1) == 1 and kv.stride(-1) == 1, (
        "qr and kv must have unit stride along the last dimension"
    )
    assert q_weight.is_contiguous() and kv_weight.is_contiguous(), (
        "q_weight and kv_weight must be contiguous"
    )

    q_size = qr.shape[1]
    kv_size = kv.shape[1]
    num_tokens = qr.shape[0]
    qr_out = torch.empty_like(qr)
    kv_out = torch.empty_like(kv)
    # Nothing to normalize. With zero-width rows, next_power_of_2(0) would
    # also yield BLOCK_SIZE == 0, which is not a valid tile size.
    if num_tokens == 0 or max(q_size, kv_size) == 0:
        return qr_out, kv_out

    # The kernel reads `size` weight elements per row through a raw pointer.
    # An undersized weight would therefore be an out-of-bounds device read.
    assert q_weight.numel() == q_size, (
        f"q_weight must have {q_size} elements, got {q_weight.numel()}"
    )
    assert kv_weight.numel() == kv_size, (
        f"kv_weight must have {kv_size} elements, got {kv_weight.numel()}"
    )
    # All four tensors are passed as raw pointers to a single launch on
    # qr.device, so they must live on the same device.
    assert qr.device == kv.device == q_weight.device == kv_weight.device, (
        "qr, kv, q_weight and kv_weight must be on the same device"
    )

    # The kernel has no loop over chunks, so one tile must span the wider row.
    # Both q and kv programs share this single BLOCK_SIZE.
    block_size = triton.next_power_of_2(max(q_size, kv_size))
    with torch_device_fn.device(qr.device):
        # Grid axis 1 selects the tensor: program 0 -> q row, 1 -> kv row.
        _fused_q_kv_rmsnorm_kernel[(num_tokens, 2)](
            qr,
            qr_out,
            q_weight,
            qr.stride(0),
            qr_out.stride(0),
            kv,
            kv_out,
            kv_weight,
            kv.stride(0),
            kv_out.stride(0),
            eps,
            Q_SIZE=q_size,
            KV_SIZE=kv_size,
            BLOCK_SIZE=block_size,
        )
    return qr_out, kv_out
# Public API of this module: only the host-side wrapper is exported; the
# underscore-prefixed Triton kernel stays private.
__all__ = [
    "fused_q_kv_rmsnorm",
]