Coverage for src/flag_gems/fused/deepseek_v4_attention_fused_q_kv_rmsnorm.py: 53%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1from typing import Tuple 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8 

9 

@triton.jit(do_not_specialize=["eps"])
def _fused_q_kv_rmsnorm_kernel(
    q_ptr,
    q_out_ptr,
    q_weight_ptr,
    q_in_stride,
    q_out_stride,
    kv_ptr,
    kv_out_ptr,
    kv_weight_ptr,
    kv_in_stride,
    kv_out_stride,
    eps,
    Q_SIZE: tl.constexpr,
    KV_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """RMS-normalize one row of either the Q input or the KV input.

    The kernel runs on a 2-D grid. Program axis 0 selects the row, and program
    axis 1 selects the task: task 0 works on the Q buffers, any other task
    value works on the KV buffers.

    Each program loads up to ``BLOCK_SIZE`` elements of its row in one shot, so
    ``BLOCK_SIZE`` has to cover the row width of whichever task it serves.
    Lanes beyond the row width are masked: they load as 0.0 (and therefore add
    nothing to the sum of squares) and are never stored.

    The arithmetic is done in float32; the result is cast to the element type
    of the output pointer right before the store. ``eps`` is listed in
    ``do_not_specialize`` on the decorator.
    """
    # The row index is widened to int64 before it is multiplied by a stride
    # (presumably to keep the offset from overflowing 32 bits on big inputs).
    row = tl.program_id(0).to(tl.int64)
    which = tl.program_id(1)

    # Select the width, row base pointers and weight pointer for this task.
    if which == 0:
        n_cols = Q_SIZE
        src = q_ptr + row * q_in_stride
        dst = q_out_ptr + row * q_out_stride
        gamma_ptr = q_weight_ptr
    else:
        n_cols = KV_SIZE
        src = kv_ptr + row * kv_in_stride
        dst = kv_out_ptr + row * kv_out_stride
        gamma_ptr = kv_weight_ptr

    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    vals = tl.load(src + cols, mask=in_bounds, other=0.0).to(tl.float32)
    gamma = tl.load(gamma_ptr + cols, mask=in_bounds, other=0.0).to(tl.float32)

    # Mean of squares over the real row width, then its reciprocal square root.
    mean_sq = tl.sum(vals * vals, axis=0) / n_cols
    inv_rms = tl.rsqrt(mean_sq + eps)

    normed = vals * inv_rms * gamma
    tl.store(dst + cols, normed.to(dst.dtype.element_ty), mask=in_bounds)

49 

50 

def fused_q_kv_rmsnorm(
    qr: torch.Tensor,
    kv: torch.Tensor,
    q_weight: torch.Tensor,
    kv_weight: torch.Tensor,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply RMSNorm to every row of ``qr`` and ``kv`` with one kernel launch.

    Args:
        qr: 2-D tensor of shape ``(num_tokens, q_size)`` whose last dimension
            has stride 1.
        kv: 2-D tensor of shape ``(num_tokens, kv_size)`` whose last dimension
            has stride 1. Must have the same number of rows as ``qr``.
        q_weight: Contiguous scale tensor for ``qr``; its first ``q_size``
            elements are used.
        kv_weight: Contiguous scale tensor for ``kv``; its first ``kv_size``
            elements are used.
        eps: Value added to the mean of squares before the reciprocal square
            root.

    Returns:
        ``(qr_out, kv_out)``: freshly allocated tensors created with
        ``torch.empty_like`` from ``qr`` and ``kv`` respectively.

    Raises:
        AssertionError: If a shape, stride, contiguity or weight-length
            requirement above is violated.
    """
    assert qr.ndim == 2 and kv.ndim == 2, "qr and kv must be 2-D tensors"
    assert qr.shape[0] == kv.shape[0], "qr and kv must have the same number of rows"
    assert (
        qr.stride(-1) == 1 and kv.stride(-1) == 1
    ), "qr and kv must have unit stride in the last dimension"
    assert (
        q_weight.is_contiguous() and kv_weight.is_contiguous()
    ), "q_weight and kv_weight must be contiguous"

    num_tokens, q_size = qr.shape
    kv_size = kv.shape[1]

    # The kernel reads the weights by flat offset 0..size-1, so a weight that
    # is shorter than the row width would be read out of bounds on the device.
    assert (
        q_weight.numel() >= q_size
    ), f"q_weight has {q_weight.numel()} elements, need at least {q_size}"
    assert (
        kv_weight.numel() >= kv_size
    ), f"kv_weight has {kv_weight.numel()} elements, need at least {kv_size}"

    qr_out = torch.empty_like(qr)
    kv_out = torch.empty_like(kv)

    # Nothing to normalize: either there are no rows, or both inputs have zero
    # columns. The latter would otherwise produce a block size of 0, which is
    # not a usable extent for the kernel's tl.arange.
    widest = max(q_size, kv_size)
    if num_tokens == 0 or widest == 0:
        return qr_out, kv_out

    # One block must span a whole row of the wider input; the kernel masks the
    # lanes that fall outside the actual row width.
    block_size = triton.next_power_of_2(widest)

    # Grid: one program per (row, task), with task 0 = Q and task 1 = KV.
    with torch_device_fn.device(qr.device):
        _fused_q_kv_rmsnorm_kernel[(num_tokens, 2)](
            qr,
            qr_out,
            q_weight,
            qr.stride(0),
            qr_out.stride(0),
            kv,
            kv_out,
            kv_weight,
            kv.stride(0),
            kv_out.stride(0),
            eps,
            Q_SIZE=q_size,
            KV_SIZE=kv_size,
            BLOCK_SIZE=block_size,
        )
    return qr_out, kv_out

90 

91 

92__all__ = ["fused_q_kv_rmsnorm"]