Coverage for src/flag_gems/runtime/backend/_sunrise/fused/fused_add_rms_norm.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import math 

3 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_kernel(
    input_ptr,  # base pointer of the input; overwritten with the normalized result
    residual_ptr,  # base pointer of the residual; overwritten with input + residual
    w_ptr,  # base pointer of the weights
    in_stride_r,  # input pointer step between consecutive rows
    in_stride_c,  # input pointer step between consecutive columns
    r_stride_r,  # residual pointer step between consecutive rows
    r_stride_c,  # residual pointer step between consecutive columns
    N,  # number of columns handled per row
    eps,  # added to the mean square before the square root
    BLOCK_SIZE: tl.constexpr,
):
    """Fused residual add + RMS norm where one program covers a row in one block.

    Each program reads up to BLOCK_SIZE columns (masked by N) of its row from
    both the input and the residual, stores their sum back to the residual,
    and stores ``sum * rsqrt(mean(sum^2) + eps) * weight`` back to the input.
    """
    # float16 / bfloat16 inputs are computed in float32; every other element
    # type is computed in its own dtype.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    # Move both base pointers to the row owned by this program.
    row = ext.program_id(0)
    input_ptr += row * in_stride_r
    residual_ptr += row * r_stride_r

    # Column offsets, the bounds mask and the pointer vectors are built once
    # and reused by every load/store below.
    offsets = tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < N
    in_ptrs = input_ptr + offsets * in_stride_c
    res_ptrs = residual_ptr + offsets * r_stride_c

    x = tl.load(in_ptrs, in_bounds, other=0.0).to(cdtype)
    r = tl.load(res_ptrs, in_bounds, other=0.0).to(cdtype)

    x += r
    # The summed value becomes the new residual.
    tl.store(res_ptrs, x, mask=in_bounds)

    # Masked lanes were loaded as 0.0, so they add nothing to the sum.
    var = tl.sum(x * x / N, axis=0)
    rrms = 1 / tl.sqrt(var + eps)

    w = tl.load(w_ptr + offsets, mask=in_bounds, other=0.0)
    y = (x * rrms * w).to(cdtype)
    # The normalized, weighted value replaces the input.
    tl.store(in_ptrs, y, mask=in_bounds)

55 

56 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_c_split_kernel(
    input_ptr,  # pointer to the input
    residual_ptr,  # pointer to the residual
    w_ptr,  # pointer to the weights
    in_stride_r,  # how much to increase the pointer when moving by 1 row
    in_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in in_ptr
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Fused residual add + RMS norm that walks a row in BLOCK_SIZE chunks.

    Pass 1 accumulates ``(input + residual)^2 / N`` over all column chunks of
    the row. Pass 2 recomputes ``input + residual`` per chunk, stores it to the
    residual, and stores ``sum * rrms * weight`` to the input.
    """
    # float16 / bfloat16 inputs are computed in float32; every other element
    # type is computed in its own dtype.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid = ext.program_id(0)
    input_ptr += pid * in_stride_r
    residual_ptr += pid * r_stride_r

    # Accumulate in the compute dtype so the loop-carried accumulator has the
    # same type as `x * x / N` on every iteration (a fixed float32 accumulator
    # would change type inside the loop when cdtype is float64).
    _var = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Column strides are applied to every access, matching the final
        # store below and the single-block kernel.
        x = tl.load(input_ptr + cols * in_stride_c, mask, other=0.0).to(cdtype)
        r = tl.load(residual_ptr + cols * r_stride_c, mask, other=0.0).to(cdtype)
        x += r
        _var += x * x / N

    var = tl.sum(_var)
    rrms = 1 / tl.sqrt(var + eps)

    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(input_ptr + cols * in_stride_c, mask, other=0.0).to(cdtype)
        r = tl.load(residual_ptr + cols * r_stride_c, mask, other=0.0).to(cdtype)
        x += r
        # write back to residual
        tl.store(residual_ptr + cols * r_stride_c, x, mask=mask)
        w = tl.load(w_ptr + cols, mask=mask, other=0.0)
        y = (x * rrms * w).to(cdtype)
        # write back to input
        tl.store(input_ptr + cols * in_stride_c, y, mask=mask)

104 

105 

def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """
    This function performs fused residual addition and RMS normalization **in-place**.
    Both `x` and `residual` tensors will be modified. Use with caution if these tensors
    are reused elsewhere or require gradients.

    Args:
        x: input tensor; on return it holds the normalized, weighted result.
        residual: residual tensor; on return it holds ``x + residual``.
        normalized_shape: trailing dimensions to normalize over; their product
            is the row length passed to the kernels.
        weight: scale tensor read by the kernels with unit stride
            (presumably ``prod(normalized_shape)`` elements - confirm with callers).
        eps: value added to the mean square before the square root.

    Returns:
        The tuple ``(x, residual)`` - the same tensor objects that were passed in.
    """
    logger.debug(
        "GEMS FUSED_ADD_RMS_NORM FORWARD, [input shape]: %s, [residual shape]: %s, [weight shape]: %s",
        x.size(),
        residual.size(),
        weight.size(),
    )
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    # Nothing to compute for empty inputs; this also avoids an empty launch
    # grid (M == 0) and a zero BLOCK_SIZE (N == 0).
    if M == 0 or N == 0:
        return x, residual

    BLOCK_SIZE = triton.next_power_of_2(N)
    # The kernels are launched with row stride N and column stride 1, so they
    # need contiguous buffers. contiguous() hands back the same tensor when it
    # is already contiguous and a copy otherwise.
    x_buf = x.contiguous()
    residual_buf = residual.contiguous()
    weight = weight.contiguous()

    with torch_device_fn.device(x.device):
        if BLOCK_SIZE <= 1024:
            # The whole row fits in a single block.
            fused_add_rms_norm_kernel[M,](
                x_buf, residual_buf, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
            )
        else:
            # Longer rows are processed in chunks of 1024 columns.
            BLOCK_SIZE = 1024
            fused_add_rms_norm_c_split_kernel[M,](
                x_buf,
                residual_buf,
                weight,
                N,
                1,
                N,
                1,
                N,
                eps,
                BLOCK_SIZE,
                num_warps=8,
            )

    # If a contiguous copy was made, the kernel wrote into the copy; copy the
    # results back so the caller's tensors are updated as documented.
    if x_buf is not x:
        x.copy_(x_buf)
    if residual_buf is not residual:
        residual.copy_(residual_buf)
    return x, residual