Coverage for src/flag_gems/runtime/backend/_cambricon/ops/rms_norm.py: 0%

199 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

11from ..utils import MAX_GRID_SIZE_X, cfggen_reduce_op 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14MAX_NRAM_C_FORWARD = 16384 * 2 

15 

16 

def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """Compute RMSNorm of ``x`` over the trailing ``normalized_shape`` dims.

    Returns ``(y, inv_rms)`` where ``y = (x * inv_rms) * weight`` and
    ``inv_rms`` is the per-row ``1/sqrt(mean(x^2) + eps)`` kept in float32
    so the backward pass can reuse it.
    """
    logger.debug("GEMS_CAMBRICON RMSNORM FORWARD")
    lead_dims = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:lead_dims])  # number of rows to normalize
    N = math.prod(normalized_shape)  # row length

    # One whole row per program when it fits on-chip; otherwise fall back to
    # the column-split kernel.
    BLOCK_SIZE = N
    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)
    # Grid-stride launch: cap the number of programs well below the hardware
    # grid limit; each program loops over rows.
    grid = (min(M, MAX_GRID_SIZE_X // 4),)
    with torch_device_fn.device(x.device):
        if BLOCK_SIZE > MAX_NRAM_C_FORWARD:
            logger.debug("GEMS_CAMBRICON RMSNORM FORWARD USING C SPLIT")
            rms_norm_kernel_C_split[grid](y, inv_rms, x, weight, N, 1, N, 1, N, eps, M)
        else:
            logger.debug("GEMS_CAMBRICON RMSNORM FORWARD NOT USING C SPLIT")
            rms_norm_kernel[grid](
                y, inv_rms, x, weight, N, 1, N, 1, N, eps, M, BLOCK_SIZE
            )
    return y, inv_rms

39 

40 

def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Backward pass of RMSNorm.

    ``dx`` is produced by a per-row kernel (column-split when the row is too
    wide for NRAM); ``dw`` is produced by a two-stage reduction: a tiled
    kernel writes float32 partial sums per row-block, then torch sums over
    the row-block axis. Returns ``(dx, dw)``.
    """
    logger.debug("GEMS_CAMBRICON RMSNORM BACKWARD")
    lead_dims = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:lead_dims])  # number of rows
    N = math.prod(normalized_shape)  # row length

    BLOCK_SIZE = N
    x = x.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)
    grid = (min(M, MAX_GRID_SIZE_X // 4),)
    with torch_device_fn.device(x.device):
        if BLOCK_SIZE > MAX_NRAM_C_FORWARD:
            logger.debug("GEMS_CAMBRICON RMSNORM BACKWARD USING C SPLIT")
            rms_norm_grad_dx_kernel_C_split[grid](
                x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, M
            )
        else:
            logger.debug("GEMS_CAMBRICON RMSNORM BACKWARD NOT USING C SPLIT")
            rms_norm_grad_dx_kernel[grid](
                x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, M, BLOCK_SIZE
            )

    # dw stage 1: each (row-block, col-block) program reduces its tile of
    # x * dy * inv_rms over rows into a float32 partial buffer.
    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

    partial_buffer = torch.empty(
        (row_block_num, N), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        rms_norm_grad_dw_kernel[row_block_num, col_block_num](
            x,
            dy,
            inv_rms,
            partial_buffer,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )
    # dw stage 2: finish the reduction over row blocks on the framework side.
    dw = torch.sum(partial_buffer, dim=0, dtype=x.dtype).reshape(-1)

    return dx, dw

92 

93 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    Y,  # pointer to the output
    INV_RMS,  # pointer to the per-row inverse rms (float32)
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # pointer increment for one row of X
    x_stride_c,  # pointer increment for one column of X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    M,  # number of rows in X
    BLOCK_SIZE: tl.constexpr,
):
    """Forward RMSNorm, one whole row per iteration; programs grid-stride
    over the M rows (BLOCK_SIZE covers the full row)."""
    step = tl.num_programs(0).to(tl.uint64)
    row = tl.program_id(0).to(tl.uint64)
    while row < M:
        y_row = Y + row * y_stride_r
        x_row = X + row * x_stride_r

        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_row + cols * x_stride_c, mask, other=0.0).to(tl.float32)

        mean_sq = tl.sum(x * x, axis=0) / N
        rrms = 1 / tl.sqrt(mean_sq + eps)

        w = tl.load(W + cols, mask=mask, other=0.0)
        # Cast the normalized value to the output dtype before weighting,
        # matching the reference rounding behavior.
        y = (x * rrms).to(y_row.dtype.element_ty) * w
        tl.store(y_row + cols * y_stride_c, y, mask=mask)
        tl.store(INV_RMS + row, rrms)
        row += step

129 

130 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel_C_split(
    Y,  # pointer to the output
    INV_RMS,  # pointer to the per-row inverse rms (float32)
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # pointer increment for one row of X
    x_stride_c,  # pointer increment for one column of X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    M,  # number of rows in X
    BLOCK_SIZE: tl.constexpr,
):
    """Forward RMSNorm for rows wider than NRAM: per row, pass 1 accumulates
    sum(x^2) over BLOCK_SIZE column tiles, pass 2 reloads each tile and
    writes the normalized, weighted output."""
    step = tl.num_programs(0).to(tl.uint64)
    row = tl.program_id(0).to(tl.uint64)
    while row < M:
        y_row = Y + row * y_stride_r
        x_row = X + row * x_stride_r

        # Pass 1: per-lane partial sums of x^2 across column tiles.
        sq_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for tile_start in range(0, N, BLOCK_SIZE):
            cols = tile_start + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x = tl.load(x_row + cols * x_stride_c, mask, other=0.0).to(tl.float32)
            sq_acc += x * x

        mean_sq = tl.sum(sq_acc, axis=0) / N
        rrms = 1 / tl.sqrt(mean_sq + eps)

        # Pass 2: normalize and apply the weight tile by tile.
        for tile_start in range(0, N, BLOCK_SIZE):
            cols = tile_start + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            w = tl.load(W + cols, mask=mask, other=0.0)
            x = tl.load(x_row + cols * x_stride_c, mask, other=0.0).to(tl.float32)
            y = (x * rrms).to(y_row.dtype.element_ty) * w
            tl.store(y_row + cols * y_stride_c, y, mask=mask)
        tl.store(INV_RMS + row, rrms)
        row += step

177 

178 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to the saved inverse rms (float32, one per row)
    DX,  # pointer to the output gradient
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # pointer increment for one row of X
    x_stride_c,  # pointer increment for one column of X
    N,  # number of columns in X
    eps,  # unused here: inv_rms is precomputed (kept for signature parity)
    M,  # number of rows in X
    BLOCK_SIZE: tl.constexpr,
):
    """dx = (dy*w - (x*inv_rms)/N * sum((x*inv_rms)*(dy*w))) * inv_rms,
    one whole row per iteration with a grid-stride loop over rows."""
    step = tl.num_programs(0).to(tl.uint64)
    row = tl.program_id(0).to(tl.uint64)
    while row < M:
        dx_row = DX + row * dx_stride_r
        x_row = X + row * x_stride_r
        dy_row = DY + row * x_stride_r

        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_row + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        inv_rms = tl.load(INV_RMS + row).to(tl.float32)
        dy = tl.load(dy_row + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask=mask, other=0.0)

        wdy = dy * w
        normalized = x * inv_rms
        row_dot = tl.sum(normalized * wdy, axis=0)
        dx = (wdy - normalized / N * row_dot) * inv_rms

        tl.store(dx_row + cols * dx_stride_c, dx, mask=mask)
        row += step

222 

223 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel_C_split(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to the saved inverse rms (float32, one per row)
    DX,  # pointer to the output gradient
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # unused here: inv_rms is precomputed (kept for signature parity)
    M,  # number of rows in X
    BLOCK_SIZE: tl.constexpr,
):
    """Column-split dx kernel for rows wider than NRAM.

    Per row: pass 1 accumulates sum((x * inv_rms) * (dy * w)) over
    BLOCK_SIZE column tiles; pass 2 reloads each tile and writes
    dx = (dy * w - (x * inv_rms) / N * row_sum) * inv_rms.
    """
    prog_num = tl.num_programs(0).to(tl.uint64)
    task_num = M
    pid = tl.program_id(0).to(tl.uint64)
    while pid < task_num:
        DX_ = DX + pid * dx_stride_r
        X_ = X + pid * x_stride_r
        DY_ = DY + pid * x_stride_r
        # inv_rms is loop-invariant for this row: load it once here instead
        # of re-loading it in every column tile (the original re-issued this
        # scalar load inside both tile loops).
        inv_rms = tl.load(INV_RMS + pid).to(tl.float32)

        acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for m_idx in range(0, N, BLOCK_SIZE):
            cols = m_idx + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x = tl.load(X_ + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
            dy = tl.load(DY_ + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
            w = tl.load(W + cols, mask=mask, other=0.0)
            dy = dy * w
            normalized = x * inv_rms
            acc += normalized * dy

        row_sum_stats = tl.sum(acc, axis=0)

        for m_idx in range(0, N, BLOCK_SIZE):
            cols = m_idx + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x = tl.load(X_ + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
            dy = tl.load(DY_ + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
            w = tl.load(W + cols, mask=mask, other=0.0)
            dy = dy * w
            normalized = x * inv_rms
            norm_val = normalized / N
            dx = (dy - norm_val * row_sum_stats) * inv_rms
            tl.store(DX_ + cols * dx_stride_c, dx, mask=mask)
        pid += prog_num

282 

283 

@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to the saved inverse rms (float32, one per row)
    DW,  # pointer to the (row_block_num, N) float32 partial-sum buffer
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # pointer increment for one row of X
    x_stride_c,  # pointer increment for one column of X
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Each program reduces one ROW_BLOCK_SIZE x COL_BLOCK_SIZE tile of
    x * dy * inv_rms over its rows and stores the per-column partial sum at
    DW[row_pid, col range]; the host finishes the reduction over row blocks."""
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    row_base = row_pid * ROW_BLOCK_SIZE
    col_base = col_pid * COL_BLOCK_SIZE

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)
    row_mask = (row_base + rows) < M
    col_mask = (col_base + cols) < N
    tile_mask = row_mask[:, None] & col_mask[None, :]

    # Absolute element offsets for this tile within X / DY.
    tile_off = (row_base + rows[:, None]) * x_stride_r + (
        col_base + cols[None, :]
    ) * x_stride_c
    x = tl.load(X + tile_off, tile_mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + row_base + rows, row_mask, other=0.0).to(tl.float32)
    dy = tl.load(DY + tile_off, tile_mask, other=0.0).to(tl.float32)

    partial = tl.sum(x * dy * inv_rms[:, None], axis=0)

    tl.store(
        DW + row_pid * N + col_base + cols,
        partial,
        mask=col_mask,
    )

337 

338 

class RmsNorm(torch.autograd.Function):
    """Autograd wrapper tying rms_norm_forward to rms_norm_backward."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        out, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        # Save exactly what the backward kernels need.
        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return out

    @staticmethod
    def backward(ctx, dy):
        saved_x, saved_inv_rms, saved_weight = ctx.saved_tensors
        dx, dw = rms_norm_backward(
            dy, saved_x, saved_inv_rms, ctx.normalized_shape, saved_weight, ctx.eps
        )
        # Gradients only for x and weight; normalized_shape and eps take None.
        return dx, None, dw, None

355 

356 

def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMSNorm to ``x`` over ``normalized_shape`` with autograd support.

    Thin public entry point that routes through the RmsNorm autograd
    Function so the custom backward kernels are used.
    """
    return RmsNorm.apply(x, normalized_shape, weight, eps)