Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/rms_norm.py: 0%

175 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Threshold (in row elements) for the single-tile kernels; rows wider than this
# fall back to the "_C_split" kernels that walk the row in BLOCK_SIZE tiles.
# Presumably sized to the device's NRAM capacity for the forward pass — TODO confirm.
MAX_NRAM_C_FORWARD = 16384 * 2

14 

15 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row; the whole row fits in a single BLOCK_SIZE tile.
    # Computes y = x / rms(x) * w and stores 1/rms(x) for the backward pass.
    row = tl.program_id(0)
    Y += row * y_stride_r
    X += row * x_stride_r

    col_offsets = tl.arange(0, BLOCK_SIZE)
    in_bounds = col_offsets < N
    inp = tl.load(X + col_offsets * x_stride_c, in_bounds, other=0.0).to(tl.float32)

    # mean of squares over the row, then the reciprocal root-mean-square
    mean_sq = tl.sum(inp * inp, axis=0) / N
    inv_rms = 1 / tl.sqrt(mean_sq + eps)

    weight = tl.load(W + col_offsets, mask=in_bounds, other=0.0)
    # cast the normalized value to the output dtype *before* applying the weight
    out = (inp * inv_rms).to(Y.dtype.element_ty) * weight
    tl.store(Y + col_offsets * y_stride_c, out, mask=in_bounds)
    tl.store(INV_RMS + row, inv_rms)

46 

47 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("common_reduce_ops"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel_C_split(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row. The row is wider than one tile, so it is walked in
    # BLOCK_SIZE-column chunks: a first pass accumulates the sum of squares,
    # a second pass normalizes and scales.
    row = tl.program_id(0)
    Y += row * y_stride_r
    X += row * x_stride_r

    # Pass 1: per-lane sum-of-squares accumulator across all tiles.
    sq_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for tile_start in range(0, N, BLOCK_SIZE):
        cols = tile_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < N
        inp = tl.load(X + cols * x_stride_c, in_bounds, other=0.0).to(tl.float32)
        sq_acc += inp * inp

    mean_sq = tl.sum(sq_acc, axis=0) / N
    inv_rms = 1 / tl.sqrt(mean_sq + eps)

    # Pass 2: re-read each tile, normalize, apply the weight, and store.
    for tile_start in range(0, N, BLOCK_SIZE):
        cols = tile_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < N
        weight = tl.load(W + cols, mask=in_bounds, other=0.0)
        inp = tl.load(X + cols * x_stride_c, in_bounds, other=0.0).to(tl.float32)
        out = (inp * inv_rms).to(Y.dtype.element_ty) * weight
        tl.store(Y + cols * y_stride_c, out, mask=in_bounds)
    tl.store(INV_RMS + row, inv_rms)

89 

90 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row. Input gradient of RMSNorm:
    #   g = dy * w
    #   dx = (g - (x * inv_rms) * sum(g * x * inv_rms) / N) * inv_rms
    row = tl.program_id(0)
    DX += row * dx_stride_r
    X += row * x_stride_r
    DY += row * x_stride_r
    INV_RMS += row

    col_offsets = tl.arange(0, BLOCK_SIZE)
    in_bounds = col_offsets < N
    inp = tl.load(X + col_offsets * x_stride_c, in_bounds, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS).to(tl.float32)
    grad_out = tl.load(DY + col_offsets * x_stride_c, in_bounds, other=0.0).to(tl.float32)
    weight = tl.load(W + col_offsets, mask=in_bounds, other=0.0)

    # fold the weight into the upstream gradient
    grad_out = grad_out * weight

    normalized = inp * inv_rms
    row_dot = tl.sum(normalized * grad_out, axis=0)

    scaled = normalized / N
    dx = (grad_out - scaled * row_dot) * inv_rms

    tl.store(DX + col_offsets * dx_stride_c, dx, mask=in_bounds)

129 

130 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("common_reduce_ops"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel_C_split(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero (unused here; kept for a uniform launch signature)
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row, processed in BLOCK_SIZE-column tiles (row too wide
    # for a single tile). Input gradient of RMSNorm:
    #   g = dy * w
    #   dx = (g - (x * inv_rms) * sum(g * x * inv_rms) / N) * inv_rms
    pid = tl.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    DY += pid * x_stride_r
    INV_RMS += pid
    # inv_rms is a per-row scalar: load it once instead of re-loading it in
    # every iteration of both loops (the original body repeated this load).
    inv_rms = tl.load(INV_RMS).to(tl.float32)

    # Pass 1: accumulate sum((x * inv_rms) * (dy * w)) across all tiles.
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for m_idx in range(0, N, BLOCK_SIZE):
        cols = m_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask=mask, other=0.0)
        dy = dy * w
        normalized = x * inv_rms
        acc += normalized * dy

    row_sum_stats = tl.sum(acc, axis=0)

    # Pass 2: re-read each tile and emit the gradient.
    for m_idx in range(0, N, BLOCK_SIZE):
        cols = m_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask=mask, other=0.0)
        dy = dy * w
        normalized = x * inv_rms
        norm_val = normalized / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms
        tl.store(DX + cols * dx_stride_c, dx, mask=mask)

184 

185 

@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DW,  # pointer to the output
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    # 2D grid of programs. Each program reduces x * dy * inv_rms over the rows
    # of one (ROW_BLOCK_SIZE, COL_BLOCK_SIZE) tile and writes a partial dw row
    # into DW[row_pid, col_range]; the host sums DW over dim 0 afterwards.
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    row_base = row_pid * ROW_BLOCK_SIZE
    col_base = col_pid * COL_BLOCK_SIZE

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    row_ok = (row_base + rows) < M
    col_ok = (col_base + cols) < N
    tile_mask = row_ok[:, None] & col_ok[None, :]

    # absolute element offsets of this tile within X / DY
    tile_off = (row_base + rows)[:, None] * x_stride_r + (
        col_base + cols
    )[None, :] * x_stride_c

    x = tl.load(X + tile_off, tile_mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + row_base + rows, row_ok, other=0.0).to(tl.float32)
    dy = tl.load(DY + tile_off, tile_mask, other=0.0).to(tl.float32)

    # per-column partial sum over the tile's rows
    partial = tl.sum(x * dy * inv_rms[:, None], axis=0)

    tl.store(
        DW + row_pid * N + col_base + cols,
        partial,
        mask=col_ok,
    )

239 

240 

class RmsNorm(torch.autograd.Function):
    """RMS normalization with Triton kernels for the tsingmicro backend.

    Forward computes ``y = x / rms(x) * weight`` over the trailing
    ``normalized_shape`` dimensions and caches ``1/rms`` per row for the
    backward pass. Rows wider than ``MAX_NRAM_C_FORWARD`` elements are handled
    by the column-split kernels.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        logger.debug("GEMS_TSINGMICRO RMSNORM FORWARD")
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])  # number of independent rows
        N = math.prod(normalized_shape)  # elements per row

        # NOTE(review): BLOCK_SIZE is N itself rather than
        # triton.next_power_of_2(N); presumably this backend accepts
        # non-power-of-2 tl.arange extents — confirm before porting.
        BLOCK_SIZE = N
        # Kernels are launched with hard-coded contiguous strides (N, 1).
        x = x.contiguous()
        weight = weight.contiguous()
        y = torch.empty_like(x)
        inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

        with torch_device_fn.device(x.device):
            if BLOCK_SIZE <= MAX_NRAM_C_FORWARD:
                logger.debug("GEMS_TSINGMICRO RMSNORM FORWARD NOT USING C SPLIT")
                rms_norm_kernel[M,](
                    y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
                )
            else:
                logger.debug("GEMS_TSINGMICRO RMSNORM FORWARD USING C SPLIT")
                # BLOCK_SIZE is chosen by the kernel's autotuner here.
                rms_norm_kernel_C_split[M,](y, inv_rms, x, weight, N, 1, N, 1, N, eps)

        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        logger.debug("GEMS_TSINGMICRO RMSNORM BACKWARD")
        x, inv_rms, weight = ctx.saved_tensors
        normalized_shape = ctx.normalized_shape
        eps = ctx.eps

        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        # BLOCK_SIZE = triton.next_power_of_2(N)
        BLOCK_SIZE = N
        x = x.contiguous()
        # BUGFIX: dy is produced by autograd and may be non-contiguous, but
        # every kernel below reads DY with the contiguous strides (N, 1).
        # Without this, a strided upstream gradient is read at wrong offsets.
        dy = dy.contiguous()
        weight = weight.contiguous()
        dx = torch.empty_like(x)

        # Input gradient: dx = (dy*w - x*inv_rms * sum(dy*w * x*inv_rms)/N) * inv_rms
        with torch_device_fn.device(x.device):
            if BLOCK_SIZE <= MAX_NRAM_C_FORWARD:
                logger.debug("GEMS_TSINGMICRO RMSNORM BACKWARD NOT USING C SPLIT")
                rms_norm_grad_dx_kernel[M,](
                    x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
                )
            else:
                logger.debug("GEMS_TSINGMICRO RMSNORM BACKWARD USING C SPLIT")
                rms_norm_grad_dx_kernel_C_split[M,](
                    x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps
                )

        # Weight gradient: each row-block of programs writes one partial row
        # of sum(x * dy * inv_rms); the final reduction over row-blocks is
        # done on the host below.
        ROW_BLOCK_SIZE = 16
        COL_BLOCK_SIZE = 256
        row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
        col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

        partial_buffer = torch.empty(
            (row_block_num, N), dtype=torch.float32, device=x.device
        )

        with torch_device_fn.device(x.device):
            rms_norm_grad_dw_kernel[row_block_num, col_block_num](
                x,
                dy,
                inv_rms,
                partial_buffer,
                N,
                1,
                N,
                1,
                M,
                N,
                ROW_BLOCK_SIZE,
                COL_BLOCK_SIZE,
            )
        dw = torch.sum(partial_buffer, dim=0, dtype=x.dtype).reshape(-1)

        # One gradient slot per forward argument: (x, normalized_shape, weight, eps)
        return dx, None, dw, None

326 

327 

def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMS normalization over the trailing ``normalized_shape`` dims of ``x``.

    Thin functional wrapper around the :class:`RmsNorm` autograd Function.

    Args:
        x: input tensor whose trailing dimensions match ``normalized_shape``.
        normalized_shape: shape of the dimensions to normalize over.
        weight: per-element scale applied after normalization.
        eps: small constant added to the mean square for numerical stability.

    Returns:
        Tensor of the same shape as ``x``.
    """
    return RmsNorm.apply(x, normalized_shape, weight, eps)