Coverage for src/flag_gems/ops/rms_norm.py: 37%

119 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # RMSNorm forward: each program instance normalizes one row of the
    # (M, N) input:  y = x / sqrt(mean(x * x) + eps) * w
    # It also stores the per-row inverse RMS so the backward kernels can
    # reuse it instead of recomputing the reduction.
    # BLOCK_SIZE must be >= N (the host passes next_power_of_2(N)), so one
    # block covers a full row and the reduction is a single tl.sum.

    # Accumulate in float32 for half-precision inputs to avoid losing
    # precision in the sum-of-squares reduction.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    # One program per row; advance both pointers to this row's start.
    # NOTE(review): uses tl.program_id here but tle.program_id in the dx
    # kernel below — presumably interchangeable; confirm project convention.
    pid = tl.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr += pid * x_stride_r

    # Out-of-range lanes (cols >= N) load 0.0, contributing nothing to the sum.
    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)

    # mean(x^2) over the row, then the inverse RMS with eps for stability.
    var = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    # Scale the normalized row by the (shared) weight vector.
    # w is loaded in its storage dtype; the multiply promotes as needed.
    w = tl.load(w_ptr + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)
    y = (x * rrms * w).to(cdtype)
    # tl.store casts y to out_ptr's element type on write.
    tl.store(out_ptr + cols * y_stride_c, y, mask=mask)
    # One float per row, consumed by the backward kernels.
    tl.store(INV_RMS + pid, rrms)

52 

53 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero (unused here: inv_rms is
    #       precomputed by the forward kernel, so eps is never applied)
    BLOCK_SIZE: tl.constexpr,
):
    # RMSNorm backward w.r.t. the input: one program per row, computing
    #   dx = (dy * w - (x * inv_rms / N) * sum(x * inv_rms * dy * w)) * inv_rms
    # All math is done in float32 regardless of the storage dtype.
    pid = tle.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    # NOTE(review): DY is stepped with X's row stride — assumes dy has the
    # same layout as x (both are made contiguous by the host wrapper).
    DY += pid * x_stride_r
    INV_RMS += pid

    # Masked lanes read 0.0 so they drop out of the row reduction below.
    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS).to(tl.float32)
    dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)

    # Fold the weight into the upstream gradient once; used in both terms.
    dy = dy * w

    # normalized_buf is the forward's normalized input x_hat = x * inv_rms.
    normalized_buf = x * inv_rms
    # Row-wise scalar: sum(x_hat * (dy * w)).
    row_sum_stats = tl.sum(normalized_buf * dy, axis=0)

    # Chain rule through the 1/rms term: subtract the projection onto x_hat.
    norm_val = normalized_buf / N
    dx = (dy - norm_val * row_sum_stats) * inv_rms

    # tl.store casts dx (float32) back to DX's element type on write.
    tl.store(DX + cols * dx_stride_c, dx, mask=mask)

92 

93 

@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DW,  # pointer to the output
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    # RMSNorm backward w.r.t. the weight. Full dw is sum over all rows of
    #   x * dy * inv_rms
    # This kernel computes per-tile partial sums on a 2-D grid
    # (row-blocks x col-blocks); each program reduces its ROW_BLOCK_SIZE
    # rows and writes one partial row of shape (N,) into DW, which the host
    # then sums over dim 0 (see rms_norm_backward).
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    # Advance to this tile's top-left element.
    offset = row_start * x_stride_r + col_start * x_stride_c
    X += offset
    DY += offset
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    # Guard the ragged last row-block / col-block.
    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N

    x = tl.load(
        X + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    # NOTE(review): DY is indexed with X's strides — assumes dy shares x's
    # layout (host makes both contiguous).
    dy = tl.load(
        DY + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    d_weight = x * dy * inv_rms[:, None]
    # Sum over rows (axis=0) - masked rows are 0 (from other=0.0 in load), so sum is correct
    # The mask ensures invalid rows contribute 0 to the sum
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    # DW is a (num_row_blocks, N) float32 buffer; row row_pid holds this
    # row-block's partial reduction.
    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )

149 

150 

def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """Run the Triton RMSNorm forward kernel.

    Flattens ``x`` into a 2-D view of ``(rows, cols)`` where ``cols`` is the
    product of ``normalized_shape`` (the trailing dims being normalized),
    launches one kernel program per row, and returns both the normalized
    output and the per-row inverse RMS (float32) needed by the backward pass.

    Args:
        x: input tensor whose trailing dims match ``normalized_shape``.
        normalized_shape: shape of the trailing dims to normalize over.
        weight: per-element scale of shape ``normalized_shape``.
        eps: numerical-stability term added to the mean of squares.

    Returns:
        ``(y, inv_rms)`` — output with ``x``'s shape/dtype, and a ``(rows,)``
        float32 tensor of inverse RMS values.
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    leading_ndim = x.ndim - len(normalized_shape)
    num_rows = math.prod(x.shape[:leading_ndim])
    num_cols = math.prod(normalized_shape)

    # One block spans a whole row, so it must be a power of two >= num_cols.
    block_size = triton.next_power_of_2(num_cols)
    x = x.contiguous()
    weight = weight.contiguous()
    out = torch.empty_like(x)
    inv_rms = torch.empty((num_rows,), device=x.device, dtype=torch.float32)

    # Contiguous 2-D view => both row strides are num_cols, col strides are 1.
    with torch_device_fn.device(x.device):
        rms_norm_kernel[(num_rows,)](
            out,
            inv_rms,
            x,
            weight,
            num_cols,
            1,
            num_cols,
            1,
            num_cols,
            eps,
            block_size,
        )

    return out, inv_rms

167 

168 

def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Run the Triton RMSNorm backward kernels.

    Computes the input gradient with a one-program-per-row kernel, then the
    weight gradient with a tiled 2-D kernel that writes per-row-block partial
    sums into a float32 scratch buffer, reduced on the host.

    Args:
        dy: upstream gradient, same shape as ``x``.
        x: the forward input tensor.
        inv_rms: per-row inverse RMS saved by the forward pass.
        normalized_shape: shape of the trailing dims that were normalized.
        weight: per-element scale of shape ``normalized_shape``.
        eps: kept for signature symmetry with the forward; the dx kernel
            reuses the precomputed ``inv_rms`` instead.

    Returns:
        ``(dx, dw)`` — gradient w.r.t. ``x`` (same shape/dtype) and a
        flattened gradient w.r.t. ``weight`` (cast to ``x.dtype``).
    """
    logger.debug("GEMS RMS_NORM BACKWARD")
    leading_ndim = x.ndim - len(normalized_shape)
    num_rows = math.prod(x.shape[:leading_ndim])
    num_cols = math.prod(normalized_shape)

    block_size = triton.next_power_of_2(num_cols)
    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    # dx: one program per row, full row in a single block.
    with torch_device_fn.device(x.device):
        rms_norm_grad_dx_kernel[(num_rows,)](
            x, dy, inv_rms, dx, weight, num_cols, 1, num_cols, 1, num_cols, eps, block_size
        )

    # dw: 2-D tiling — each row-block writes one partial (N,) row.
    row_tile = 16
    col_tile = 256
    n_row_blocks = triton.cdiv(num_rows, row_tile)
    n_col_blocks = triton.cdiv(num_cols, col_tile)

    # float32 scratch holding one partial dw row per row-block.
    partial_dw = torch.empty(
        (n_row_blocks, num_cols), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        rms_norm_grad_dw_kernel[n_row_blocks, n_col_blocks](
            x,
            dy,
            inv_rms,
            partial_dw,
            num_cols,
            1,
            num_cols,
            1,
            num_rows,
            num_cols,
            row_tile,
            col_tile,
        )
    # Final reduction over row-blocks in float32, then cast back and flatten.
    dw = torch.sum(partial_dw, dim=0, dtype=torch.float32).to(x.dtype).reshape(-1)

    return dx, dw

217 

218 

class RmsNorm(torch.autograd.Function):
    """Autograd bridge for the Triton RMSNorm kernels.

    ``forward`` saves the input, the per-row inverse RMS, and the weight so
    ``backward`` can feed them straight into the gradient kernels. Gradients
    are returned only for the tensor inputs (``x`` and ``weight``); the
    ``normalized_shape`` and ``eps`` slots get ``None``.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        out, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        # Stash everything backward needs; non-tensor config goes on ctx.
        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return out

    @staticmethod
    def backward(ctx, dy):
        saved_x, saved_inv_rms, saved_weight = ctx.saved_tensors
        dx, dw = rms_norm_backward(
            dy, saved_x, saved_inv_rms, ctx.normalized_shape, saved_weight, ctx.eps
        )
        # One gradient slot per forward argument: x, normalized_shape, weight, eps.
        return dx, None, dw, None

236 

237 

def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMS normalization over the trailing ``normalized_shape`` dims.

    Thin functional entry point that dispatches through the autograd-aware
    :class:`RmsNorm` Function so gradients flow to ``x`` and ``weight``.
    """
    return RmsNorm.apply(x, normalized_shape, weight, eps)