Coverage for src/flag_gems/runtime/backend/_ascend/ops/rms_norm.py: 0%

125 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-level logger named after this op file, e.g. "flag_gems.runtime._ascend.ops.rms_norm".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

13 

14 

# Forward RMSNorm kernel: one program per row. Computes
#   y = (x / sqrt(mean(x^2) + eps)) * w
# and also stores the per-row inverse RMS into INV_RMS[pid] so the backward
# pass can reuse it instead of recomputing the reduction.
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms (one float32 per row)
    X,  # pointer to the input
    W,  # pointer to the weights (length N)
    y_stride_r,  # output row stride
    y_stride_c,  # output column stride
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles exactly one row of the (M, N) view.
    pid = tle.program_id(0)
    Y += pid * y_stride_r
    X += pid * x_stride_r

    # Pass 1: accumulate mean(x^2) in float32 for numerical stability.
    # Dividing by N inside the loop keeps partial sums small.
    var = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # NOTE(review): loads use `X + cols` and ignore x_stride_c, i.e. they
        # assume the row is contiguous — this matches the launch in
        # RmsNorm.forward (stride 1), but confirm before reusing with other strides.
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        var += tl.sum(x * x / N)

    # Inverse root-mean-square for this row.
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: normalize, cast back to the output dtype, then apply the weight.
    # The cast happens BEFORE the weight multiply, so w is applied in the
    # output dtype rather than in float32.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0)
        y = (x * rrms).to(Y.dtype.element_ty) * w
        tl.store(Y + cols * y_stride_c, y, mask=mask)

    # Save the row's inverse RMS for the backward kernels.
    tl.store(INV_RMS + pid, rrms)

52 

53 

# Backward RMSNorm kernel for the input gradient: one program per row.
# Given dy (upstream grad), w, and the saved inverse RMS, computes
#   dx = (dy*w - (x*inv_rms)/N * sum(x*inv_rms * dy*w)) * inv_rms
# i.e. the standard RMSNorm input gradient, using a two-pass scheme:
# first reduce the row to a scalar, then produce dx element-wise.
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms (saved by the forward kernel)
    DX,  # pointer to the output (input gradient)
    W,  # pointer to the weights
    dx_stride_r,  # output row stride
    dx_stride_c,  # output column stride
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero (unused here: inv_rms is precomputed)
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    # NOTE(review): DY is advanced with x_stride_r, i.e. dy is assumed to have
    # the same (contiguous) layout as x — confirm against the launcher.
    DY += pid * x_stride_r
    INV_RMS += pid

    inv_rms = tl.load(INV_RMS).to(tl.float32)

    # Pass 1: reduce sum((x*inv_rms) * (dy*w)) over the row.
    row_sum_stats = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        row_sum_stats += tl.sum(normalized_buf * dy)

    # Pass 2: element-wise dx using the row-level scalar from pass 1.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        norm_val = normalized_buf / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms
        tl.store(DX + cols * dx_stride_c, dx, mask=mask)

100 

101 

# Backward RMSNorm kernel for the weight gradient. 2-D launch grid:
# axis 0 tiles rows, axis 1 tiles columns. Each program reduces its
# (ROW_BLOCK_SIZE, COL_BLOCK_SIZE) tile of x*dy*inv_rms over the row axis and
# writes a PARTIAL per-column sum into DW[row_pid, col]; the host finishes the
# reduction over row blocks (see torch.sum in RmsNorm.backward).
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms (one per row)
    DW,  # pointer to the output: (row_block_num, N) partial-sum buffer
    dx_stride_r,  # unused here
    dx_stride_c,  # unused here
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    # Top-left corner of this program's tile.
    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    offset = row_start * x_stride_r + col_start * x_stride_c
    X += offset
    DY += offset
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    # Mask out-of-range rows/cols at the ragged edges of the grid.
    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N

    x = tl.load(
        X + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    dy = tl.load(
        DY + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    # dL/dw contribution of this tile: sum over rows of x_hat * dy,
    # where x_hat = x * inv_rms (broadcast per row).
    d_weight = x * dy * inv_rms[:, None]
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    # One partial row of length N per row-block; col_pid programs fill
    # disjoint column slices of it.
    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )

155 

156 

class RmsNorm(torch.autograd.Function):
    """Autograd function for RMS normalization backed by Triton kernels.

    Forward normalizes the trailing ``normalized_shape`` dimensions of ``x``
    by their root-mean-square and scales by ``weight``; backward produces the
    input gradient and the weight gradient (via a two-stage partial-sum
    reduction). ``x`` is viewed as an (M, N) matrix where N is the product of
    ``normalized_shape`` and M is the product of the leading dims.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        """Compute ``y = x / sqrt(mean(x^2) + eps) * weight``.

        Args:
            ctx: autograd context; stashes (x, inv_rms, weight) for backward.
            x: input tensor whose trailing dims match ``normalized_shape``.
            normalized_shape: shape of the trailing dims to normalize over.
            weight: scale tensor with ``prod(normalized_shape)`` elements.
            eps: small constant added to the mean square for stability.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        # Fixed: the log message previously said "LAYERNORM" (copy-paste from
        # the layer_norm op); this is the RMSNorm path.
        logger.debug("GEMS_ASCEND RMSNORM FORWARD")
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])  # number of rows to normalize
        N = math.prod(normalized_shape)  # elements per row

        # Cap the block size for very wide rows; 12064 is presumably an
        # Ascend-specific capacity limit — confirm against hardware docs.
        BLOCK_SIZE = min(triton.next_power_of_2(N), 12064)

        # The kernels index rows with stride N and columns with stride 1,
        # so all operands must be contiguous.
        x = x.contiguous()
        weight = weight.contiguous()
        y = torch.empty_like(x)
        # inv_rms is kept in float32 regardless of x.dtype for accuracy.
        inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

        with torch_device_fn.device(x.device):
            # One program per row; strides are (N, 1) for both y and x.
            rms_norm_kernel[M,](y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)

        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        """Compute gradients w.r.t. ``x`` and ``weight``.

        Returns:
            (dx, None, dw, None) matching forward's
            (x, normalized_shape, weight, eps) signature.
        """
        # Fixed: log message previously said "LAYERNORM".
        logger.debug("GEMS_ASCEND RMSNORM BACKWARD")
        x, inv_rms, weight = ctx.saved_tensors
        normalized_shape = ctx.normalized_shape
        eps = ctx.eps

        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        BLOCK_SIZE = min(triton.next_power_of_2(N), 6912)
        x = x.contiguous()
        weight = weight.contiguous()
        # Fixed: dy was not made contiguous although the dx/dw kernels index
        # DY assuming a row-contiguous (N, 1) layout, which would silently
        # compute wrong gradients for a non-contiguous upstream grad.
        dy = dy.contiguous()
        dx = torch.empty_like(x)

        with torch_device_fn.device(x.device):
            # eps is part of the kernel signature but unused there (inv_rms
            # was precomputed in forward).
            rms_norm_grad_dx_kernel[M,](
                x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
            )

        # Weight gradient: each (row-block, col-block) program writes a
        # partial per-column sum; the final reduction over row blocks is done
        # on the host below.
        ROW_BLOCK_SIZE = 16
        COL_BLOCK_SIZE = 256
        row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
        col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

        # float32 accumulation buffer, one partial row per row block.
        partial_buffer = torch.empty(
            (row_block_num, N), dtype=torch.float32, device=x.device
        )

        with torch_device_fn.device(x.device):
            rms_norm_grad_dw_kernel[row_block_num, col_block_num](
                x,
                dy,
                inv_rms,
                partial_buffer,
                N,
                1,
                N,
                1,
                M,
                N,
                ROW_BLOCK_SIZE,
                COL_BLOCK_SIZE,
            )
        # Finish the reduction and cast back to the parameter dtype.
        dw = torch.sum(partial_buffer, dim=0, dtype=x.dtype).reshape(-1)

        return dx, None, dw, None

228 

229 

def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMS normalization over the trailing ``normalized_shape`` dims of ``x``.

    Thin functional wrapper that dispatches to the autograd-aware
    :class:`RmsNorm` implementation.
    """
    result = RmsNorm.apply(x, normalized_shape, weight, eps)
    return result