Coverage for src/flag_gems/runtime/backend/_cambricon/fused/fused_add_rms_norm.py: 0%

63 statements

Report generated by coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2import math 

3 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

12logger = logging.getLogger(__name__) 

13 

14 

def get_configs():
    """Build the autotune search space for the fused add + RMSNorm kernel.

    Enumerates column block sizes, row-tile widths, and pipeline stage
    counts; every candidate runs with a single warp.
    """
    return [
        triton.Config(
            {"M_BLOCK": rows_per_tile, "BLOCK_SIZE": cols_per_block},
            num_stages=stages,
            num_warps=1,
        )
        for cols_per_block in [2048, 1024, 512]
        for rows_per_tile in range(1, 10, 2)
        for stages in [1, 5]
    ]

28 

29 

@triton.autotune(
    configs=get_configs(),
    key=["M", "N_COLS"],
    # Both buffers are mutated in place, so autotuning must restore them
    # between candidate runs to keep timing runs from corrupting the data.
    restore_value=["x_ptr", "r_ptr"],
)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_kernel(
    x_ptr,  # input rows; overwritten with the normalized, weighted output
    r_ptr,  # residual rows; overwritten with x + residual
    w_ptr,  # RMSNorm weight vector of length N_COLS
    eps,  # numerical-stability term added to the mean square before rsqrt
    stride,  # element stride between consecutive rows of x and residual
    M,  # total number of rows
    N_COLS: tl.constexpr,  # row length (the normalized dimension)
    BLOCK_SIZE: tl.constexpr,  # columns handled per inner-loop step
    M_BLOCK: tl.constexpr,  # rows handled together per outer-loop step
):
    """In-place fused residual-add + RMSNorm.

    Pass 1 writes ``x + residual`` back into ``r_ptr`` while accumulating
    per-row sums of squares; pass 2 re-reads that sum from ``r_ptr``,
    scales it by ``1/sqrt(mean(sq) + eps)`` and the weight, and writes the
    result into ``x_ptr``.
    """
    pid = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    # Split the M rows evenly over programs; this program owns [lb, ub).
    M_OUT_BLOCK = tl.cdiv(M, pnum)

    lb = pid * M_OUT_BLOCK
    ub = tl.minimum((pid + 1) * M_OUT_BLOCK, M)
    for m_start in range(lb, ub, M_BLOCK):
        m_offset = m_start + tl.arange(0, M_BLOCK)
        mx_ptr = x_ptr + stride * m_offset
        mr_ptr = r_ptr + stride * m_offset
        # Per-row running sum of squares, accumulated column-block-wise
        # and reduced after the first pass.
        _mean = tl.zeros([M_BLOCK, BLOCK_SIZE], dtype=tl.float32)
        for offset in range(0, N_COLS, BLOCK_SIZE):
            cols = offset + tl.arange(0, BLOCK_SIZE)
            # Mask against ub (not M): rows past ub belong to the next
            # program and must not be touched twice.
            row_mask = m_offset < ub
            col_mask = cols < N_COLS
            mask = row_mask[:, None] & col_mask[None, :]
            x = tl.load(mx_ptr[:, None] + cols[None, :], mask=mask, other=0.0).to(
                tl.float32
            )
            r = tl.load(mr_ptr[:, None] + cols[None, :], mask=mask, other=0.0).to(
                tl.float32
            )
            xpr = x + r
            # Persist x + residual into the residual buffer so pass 2 can
            # re-read it (and so the caller receives the summed residual).
            tl.store(mr_ptr[:, None] + cols[None, :], xpr, mask=mask)
            _mean += xpr * xpr

        # Since `_mean * (1 / N_COLS)` performs better, make this change.
        # var = tl.sum(_mean / N_COLS, axis=1)
        var = tl.sum(_mean * (1.0 / N_COLS), axis=1)
        rrms = 1.0 / tl.sqrt(var + eps)

        for offset in range(0, N_COLS, BLOCK_SIZE):
            cols = offset + tl.arange(0, BLOCK_SIZE)
            row_mask = m_offset < ub
            col_mask = cols < N_COLS
            mask = row_mask[:, None] & col_mask[None, :]

            # Re-read the summed value written by pass 1.
            xpr = tl.load(mr_ptr[:, None] + cols[None, :], mask=mask, other=0.0).to(
                tl.float32
            )
            w = tl.load(w_ptr + cols, mask=col_mask, other=0.0).to(tl.float32)
            y = xpr * rrms[:, None]
            y = y * w
            # Cast back to the input dtype before storing the final output
            # over the original x rows.
            y = y.to(x_ptr.dtype.element_ty)
            tl.store(mx_ptr[:, None] + cols[None, :], y, mask=mask)

93 

94 

def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """Fused residual addition and RMS normalization, performed **in-place**.

    On return, ``x`` holds the normalized output and ``residual`` holds
    ``x + residual``. Use with caution if these tensors are reused
    elsewhere or require gradients. NOTE(review): if an input is
    non-contiguous, ``.contiguous()`` copies it and only the returned
    copy is mutated — the caller's original tensor is left untouched.
    """
    logger.debug(
        "GEMS_CAMBRICON FUSED_ADD_RMS_NORM FORWARD, [input shape]: %s, [residual shape]: %s, [weight shape]: %s",
        x.size(),
        residual.size(),
        weight.size(),
    )
    # Leading dims collapse into rows; the normalized dims collapse into
    # a single row length.
    batch_ndim = x.ndim - len(normalized_shape)
    num_rows = math.prod(x.shape[:batch_ndim])
    row_len = math.prod(normalized_shape)

    x, residual, weight = (
        x.contiguous(),
        residual.contiguous(),
        weight.contiguous(),
    )

    with torch_device_fn.device(x.device):
        # One program per physical core; the kernel partitions rows itself.
        fused_add_rms_norm_kernel[TOTAL_CORE_NUM,](
            x, residual, weight, eps, x.stride(batch_ndim - 1), num_rows, row_len
        )
    return x, residual