Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/fused_add_rms_norm.py: 0%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import builtins 

2import logging 

3import math 

4 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger: a child of the package-level "flag_gems" logger; lstrip(".")
# removes any leading dot from __name__ before it is used as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

# Single-tile kernel: one program handles one full row, so the whole row
# must fit in one BLOCK_SIZE tile (BLOCK_SIZE >= N).  It computes
#   R <- x + r            (pre-normalization sum, written back to residual)
#   X <- rmsnorm(x + r) * W   (normalized, weighted output, written to input)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rmsnorm_kernel(
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    row = tle.program_id(0)
    X += row * x_stride_r
    R += row * r_stride_r

    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < N
    # Accumulate in float32 regardless of the storage dtype.
    hidden = tl.load(X + col_offsets * x_stride_c, col_mask, other=0.0).to(tl.float32)
    res = tl.load(R + col_offsets * r_stride_c, col_mask, other=0.0).to(tl.float32)

    hidden += res
    # The residual buffer receives the pre-normalization sum.
    tl.store(R + col_offsets * r_stride_c, hidden, mask=col_mask)

    # mean(x^2) over the row, then the reciprocal RMS.
    mean_sq = tl.sum(hidden * hidden / N, axis=0)
    inv_rms = 1 / tl.sqrt(mean_sq + eps)

    weight_vals = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=col_mask, other=0.0)
    # Cast back to the storage dtype before applying the weight.
    out = (hidden * inv_rms).to(X.dtype.element_ty) * weight_vals
    # The input buffer receives the normalized, weighted output.
    tl.store(X + col_offsets * x_stride_c, out, mask=col_mask)

49 

50 

# Tiled variant for rows longer than one block (N > BLOCK_SIZE).
# Pass 1 accumulates sum((x + r)^2) across tiles; pass 2 reloads each tile,
# recomputes x + r, and writes the sum back to R and the normalized,
# weighted output back to X.
@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rmsnorm_kernel_tile(
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weight
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Use tle.program_id for consistency with fused_add_rmsnorm_kernel.
    pid = tle.program_id(0)
    X += pid * x_stride_r
    R += pid * r_stride_r

    # Pass 1: per-lane partial sums of squares, reduced after the loop.
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Fix: honor the column strides.  The original indexed X/R with a
        # hard-coded unit stride in the loads while the stores below used
        # x_stride_c / r_stride_c — correct only because the caller always
        # passes contiguous tensors with stride 1.
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        _var += x * x / N
    var = tl.sum(_var)
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: recompute x + r tile by tile and write both outputs.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        w = tl.load(W + cols, mask, other=0.0)
        y = (x * rrms).to(X.dtype.element_ty) * w
        # write back to residual and input
        tl.store(R + cols * r_stride_c, x, mask=mask)
        tl.store(X + cols * x_stride_c, y, mask=mask)

98 

99 

def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """
    Fused residual addition and RMS normalization, performed **in-place**.

    Computes ``residual <- x + residual`` and
    ``x <- rmsnorm(x + residual) * weight`` over the trailing
    ``normalized_shape`` dimensions, returning ``(x, residual)``.

    Both tensors are overwritten — use with caution if they are reused
    elsewhere or require gradients.  NOTE(review): when an input is
    non-contiguous, the ``.contiguous()`` calls below produce copies, so
    the caller's original tensor is left untouched; rely on the returned
    tensors rather than on in-place mutation in that case.
    """
    logger.debug("GEMS FUSED_ADD_RMS_NORM FORWARD")
    batch_ndim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:batch_ndim])
    N = math.prod(normalized_shape)

    # core_num * buffer_size_limit caps the tile width on this backend.
    MAX_BLOCK = 64 * 128
    BLOCK_SIZE = builtins.min(MAX_BLOCK, triton.next_power_of_2(N))

    x = x.contiguous()
    residual = residual.contiguous()
    weight = weight.contiguous()

    # Rows wider than one tile take the two-pass tiled kernel.
    if N > MAX_BLOCK:
        kernel = fused_add_rmsnorm_kernel_tile
    else:
        kernel = fused_add_rmsnorm_kernel
    with torch_device_fn.device(x.device):
        kernel[M,](x, residual, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)
    return x, residual