Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/rwkv_ka_fusion.py: 0%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def rwkv_ka_fusion_kernel(
    k_ptr,
    kk_ptr,
    a_ptr,
    ka_ptr,
    o_k_ptr,
    o_kk_ptr,
    o_kka_ptr,
    T,
    C,
    H,
    N,
    N_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Fused per-head normalisation and scaling of ``k``.

    Each program starts at flat offset ``pid * block_size`` and walks ``H``
    consecutive chunks of ``N`` elements. For every chunk it computes::

        kt   = k * kk
        okk  = kt / sqrt(sum(kt * kt) + 1e-12)
        ok   = k * (1 + (a - 1) * ka)
        okka = okk * a

    and stores ``ok``, ``okk`` and ``okka`` at the same flat offsets the
    ``k`` values were read from.

    Args:
        k_ptr, a_ptr: inputs indexed by flat offset, bounded by ``T * C``.
        kk_ptr, ka_ptr: inputs indexed by ``i * N + lane``, bounded by ``C``.
        o_k_ptr, o_kk_ptr, o_kka_ptr: outputs, indexed like ``k_ptr``.
        T, C: the flat bound for ``k``/``a``/outputs is ``T * C``.
        H: number of ``N``-sized chunks processed per program.
        N: number of valid elements per chunk.
        N_size: compile-time tile width; must be >= ``N``.
        block_size: compile-time stride between consecutive programs.

    NOTE(review): presumably ``C == H * N`` and ``block_size == C`` so that
    one program covers exactly one row -- confirm against the launcher.
    """
    pid = tl.program_id(axis=0)
    k_start = pid * block_size

    # The tile is N_size lanes wide, which can exceed N. Lanes at or beyond N
    # belong to the next chunk (or the next program's row), so they must be
    # masked out of every load and store. Otherwise they leak into this
    # chunk's sum of squares and overwrite results owned by someone else.
    lane = tl.arange(0, N_size)
    lane_in_chunk = lane < N

    for i in range(0, H):
        offs = k_start + i * N + lane
        row_mask = lane_in_chunk & (offs < T * C)
        k = tl.load(k_ptr + offs, mask=row_mask, other=0.0)
        a = tl.load(a_ptr + offs, mask=row_mask, other=0.0)

        c_offs = i * N + lane
        chan_mask = lane_in_chunk & (c_offs < C)
        ka = tl.load(ka_ptr + c_offs, mask=chan_mask, other=0.0)
        kk = tl.load(kk_ptr + c_offs, mask=chan_mask, other=0.0)

        # Masked lanes were loaded as 0.0, so they add nothing to the sum.
        kt = k * kk
        kt2 = kt * kt
        norm_kt2 = tl.sum(kt2.to(tl.float32))
        # The epsilon keeps the division finite when the whole chunk is zero.
        norm_kt = tl.sqrt(norm_kt2 + 1e-12)
        okk = kt / norm_kt
        tl.store(o_kk_ptr + offs, okk, mask=row_mask)

        ok = k * (1 + (a.to(tl.float32) - 1) * ka)
        okka = okk * a
        tl.store(o_k_ptr + offs, ok, mask=row_mask)
        tl.store(o_kka_ptr + offs, okka, mask=row_mask)
def rwkv_ka_fusion(
    k: torch.Tensor, kk: torch.Tensor, a: torch.Tensor, ka: torch.Tensor, H: int, N: int
):
    """Launch ``rwkv_ka_fusion_kernel`` over ``k`` and return its three outputs.

    Args:
        k: 1-D tensor of ``C`` elements (treated as a single row) or 2-D
            tensor of shape ``(T, C)``.
        kk: tensor the kernel indexes by flat channel offset, up to ``C``.
        a: tensor the kernel indexes with the same flat offsets as ``k``.
        ka: tensor the kernel indexes by flat channel offset, up to ``C``.
        H: number of chunks the kernel processes per row.
        N: number of elements per chunk.

    Returns:
        ``(o_k, o_kk, o_kka)``: three new tensors with the shape and dtype of
        ``k``. They are returned without launching the kernel when ``k`` has
        no elements.

    Raises:
        ValueError: if ``k`` is neither 1-D nor 2-D (from unpacking its shape).

    NOTE(review): presumably ``C == H * N`` and ``a`` has the same shape as
    ``k`` -- neither is checked here; confirm against callers.
    """
    if k.dim() == 1:
        T = 1
        C = k.shape[0]
    else:
        T, C = k.shape

    # The kernel addresses every tensor by flat row-major offset, so strided
    # views must be materialised first. These calls are no-ops for inputs
    # that are already contiguous.
    k = k.contiguous()
    kk = kk.contiguous()
    a = a.contiguous()
    ka = ka.contiguous()

    o_k = torch.empty_like(k)
    o_kk = torch.empty_like(k)
    o_kka = torch.empty_like(k)

    # Nothing to compute for empty input. C == 0 would also make the
    # block size zero, which previously caused a division by zero in the grid.
    if T == 0 or C == 0:
        return o_k, o_kk, o_kka

    # One program per row: the block size equals the row length.
    BLOCK_SIZE = C
    grid = (T,)
    # The kernel's tile width must be a power of two; it masks lanes >= N.
    N_size = triton.next_power_of_2(N)
    rwkv_ka_fusion_kernel[grid](
        k, kk, a, ka, o_k, o_kk, o_kka, T, C, H, N, N_size, BLOCK_SIZE
    )

    return o_k, o_kk, o_kka