Coverage for src/flag_gems/runtime/backend/_cambricon/ops/rms_norm.py: 0%
199 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
11from ..utils import MAX_GRID_SIZE_X, cfggen_reduce_op
# Child logger under the package-wide "flag_gems" logger; lstrip(".") guards
# against a leading dot in __name__ (e.g. if imported relatively).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Column-count threshold that selects between the single-tile kernels and the
# C-split kernels below. NOTE(review): presumably the per-row element budget
# that fits in on-chip NRAM — confirm against the Cambricon hardware spec.
MAX_NRAM_C_FORWARD = 16384 * 2
def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """Run the RMSNorm forward pass on Cambricon hardware.

    Flattens ``x`` into an (M, N) view where N is the product of
    ``normalized_shape``, then dispatches to a single-tile kernel when a whole
    row fits on chip, or to the column-split kernel otherwise.

    Returns:
        tuple: ``(y, inv_rms)`` — the normalized output (same shape/dtype as
        ``x``) and a float32 tensor of per-row inverse RMS values of shape (M,).
    """
    logger.debug("GEMS_CAMBRICON RMSNORM FORWARD")
    reduce_start = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:reduce_start])
    N = math.prod(normalized_shape)

    # Full-row tile width; deliberately NOT rounded up to a power of two.
    BLOCK_SIZE = N
    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)
    grid = (min(M, MAX_GRID_SIZE_X // 4),)
    with torch_device_fn.device(x.device):
        if BLOCK_SIZE > MAX_NRAM_C_FORWARD:
            # Row too wide for a single tile: iterate over column chunks.
            logger.debug("GEMS_CAMBRICON RMSNORM FORWARD USING C SPLIT")
            rms_norm_kernel_C_split[grid](y, inv_rms, x, weight, N, 1, N, 1, N, eps, M)
        else:
            logger.debug("GEMS_CAMBRICON RMSNORM FORWARD NOT USING C SPLIT")
            rms_norm_kernel[grid](
                y, inv_rms, x, weight, N, 1, N, 1, N, eps, M, BLOCK_SIZE
            )
    return y, inv_rms
def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Run the RMSNorm backward pass on Cambricon hardware.

    Computes the input gradient ``dx`` with a row-parallel kernel (single-tile
    or column-split, mirroring the forward dispatch), then computes the weight
    gradient ``dw`` by tiling the (M, N) problem into blocks, writing per-block
    partial sums to a staging buffer, and reducing it on the host.

    Returns:
        tuple: ``(dx, dw)`` — gradient w.r.t. ``x`` (same shape as ``x``) and
        gradient w.r.t. ``weight`` (flat, length N, dtype of ``x``).
    """
    logger.debug("GEMS_CAMBRICON RMSNORM BACKWARD")
    reduce_start = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:reduce_start])
    N = math.prod(normalized_shape)

    # Full-row tile width; deliberately NOT rounded up to a power of two.
    BLOCK_SIZE = N
    x = x.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)
    grid = (min(M, MAX_GRID_SIZE_X // 4),)
    with torch_device_fn.device(x.device):
        if BLOCK_SIZE > MAX_NRAM_C_FORWARD:
            logger.debug("GEMS_CAMBRICON RMSNORM BACKWARD USING C SPLIT")
            rms_norm_grad_dx_kernel_C_split[grid](
                x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, M
            )
        else:
            logger.debug("GEMS_CAMBRICON RMSNORM BACKWARD NOT USING C SPLIT")
            rms_norm_grad_dx_kernel[grid](
                x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, M, BLOCK_SIZE
            )

    # dw: each grid row writes one partial row of shape (N,); sum them after.
    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)
    partial_buffer = torch.empty(
        (row_block_num, N), dtype=torch.float32, device=x.device
    )
    with torch_device_fn.device(x.device):
        rms_norm_grad_dw_kernel[row_block_num, col_block_num](
            x,
            dy,
            inv_rms,
            partial_buffer,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )
    dw = torch.sum(partial_buffer, dim=0, dtype=x.dtype).reshape(-1)
    return dx, dw
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,  # output row stride
    y_stride_c,  # output column stride
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    M,  # number of rows in X
    BLOCK_SIZE: tl.constexpr,  # tile width; the launcher passes N itself
):
    # Single-tile RMSNorm forward: a whole row fits in one BLOCK_SIZE tile.
    # Programs stride over rows: program p handles rows p, p+P, p+2P, ...
    prog_num = tl.num_programs(0).to(tl.uint64)
    task_num = M
    pid = tl.program_id(0).to(tl.uint64)
    while pid < task_num:
        Y_ = Y + pid * y_stride_r
        X_ = X + pid * x_stride_r

        mask = tl.arange(0, BLOCK_SIZE) < N
        cols = tl.arange(0, BLOCK_SIZE)
        # Compute in fp32 regardless of storage dtype for numerical stability.
        x = tl.load(X_ + cols * x_stride_c, mask, other=0.0).to(tl.float32)

        var = tl.sum(x * x, axis=0) / N  # mean of squares over the row
        rrms = 1 / tl.sqrt(var + eps)  # inverse RMS

        w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)
        # NOTE: the normalized value is cast to the output dtype BEFORE the
        # weight multiply, so that product happens in output precision.
        y = (x * rrms).to(Y_.dtype.element_ty) * w
        tl.store(Y_ + cols * y_stride_c, y, mask=mask)
        # Per-row inverse RMS is saved for reuse by the backward pass.
        tl.store(INV_RMS + pid, rrms)
        pid += prog_num
@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel_C_split(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    X,  # pointer to the input
    W,  # pointer to the weights
    y_stride_r,  # output row stride
    y_stride_c,  # output column stride
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    M,  # number of rows in X
    BLOCK_SIZE: tl.constexpr,  # column tile width chosen by the autotuner
):
    # Two-pass RMSNorm forward for rows too wide for one on-chip tile:
    # pass 1 accumulates sum(x^2) tile by tile, pass 2 normalizes and stores.
    prog_num = tl.num_programs(0).to(tl.uint64)
    task_num = M
    pid = tl.program_id(0).to(tl.uint64)
    while pid < task_num:
        Y_ = Y + pid * y_stride_r
        X_ = X + pid * x_stride_r

        # Pass 1: per-lane partial sums of x^2 in fp32, reduced afterwards.
        var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for m_idx in range(0, N, BLOCK_SIZE):
            cols = m_idx + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x = tl.load(X_ + cols * x_stride_c, mask, other=0.0).to(tl.float32)
            var += x * x

        var = tl.sum(var, axis=0) / N  # mean of squares over the row
        rrms = 1 / tl.sqrt(var + eps)  # inverse RMS

        # Pass 2: reload each tile, normalize, weight, and store.
        for m_idx in range(0, N, BLOCK_SIZE):
            cols = m_idx + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            w = tl.load(W + cols, mask=mask, other=0.0)
            x = tl.load(X_ + cols * x_stride_c, mask, other=0.0).to(tl.float32)
            # Cast to output dtype before the weight multiply (matches the
            # single-tile kernel's precision behavior).
            y = (x * rrms).to(Y_.dtype.element_ty) * w
            tl.store(Y_ + cols * y_stride_c, y, mask=mask)
        # Per-row inverse RMS saved for the backward pass.
        tl.store(INV_RMS + pid, rrms)
        pid += prog_num
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms (precomputed by the forward pass)
    DX,  # pointer to the output gradient
    W,  # pointer to the weights
    dx_stride_r,  # output-gradient row stride
    dx_stride_c,  # output-gradient column stride
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # not read in this kernel; eps was already folded into inv_rms
    M,  # number of rows in X
    BLOCK_SIZE: tl.constexpr,  # tile width; the launcher passes N itself
):
    # Single-tile dx kernel:
    #   dx = (dy*w - (x*inv_rms)/N * sum(x*inv_rms * dy*w)) * inv_rms
    # Programs stride over rows: program p handles rows p, p+P, p+2P, ...
    prog_num = tl.num_programs(0).to(tl.uint64)
    task_num = M
    pid = tl.program_id(0).to(tl.uint64)
    while pid < task_num:
        DX_ = DX + pid * dx_stride_r
        X_ = X + pid * x_stride_r
        # NOTE: DY is indexed with X's row stride; assumes dy has x's layout.
        DY_ = DY + pid * x_stride_r
        INV_RMS_ = INV_RMS + pid

        mask = tl.arange(0, BLOCK_SIZE) < N
        cols = tl.arange(0, BLOCK_SIZE)
        # Compute in fp32 regardless of storage dtype for numerical stability.
        x = tl.load(X_ + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        inv_rms = tl.load(INV_RMS_).to(tl.float32)
        dy = tl.load(DY_ + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)

        # Fold the weight into the upstream gradient.
        dy = dy * w

        normalized_buf = x * inv_rms
        # Row-wide correction term: sum over columns of normalized * dy.
        row_sum_stats = tl.sum(normalized_buf * dy, axis=0)

        norm_val = normalized_buf / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms

        tl.store(DX_ + cols * dx_stride_c, dx, mask=mask)
        pid += prog_num
@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel_C_split(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms (precomputed by the forward pass)
    DX,  # pointer to the output gradient
    W,  # pointer to the weights
    dx_stride_r,  # output-gradient row stride
    dx_stride_c,  # output-gradient column stride
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # not read in this kernel; eps was already folded into inv_rms
    M,  # number of rows in X
    BLOCK_SIZE: tl.constexpr,  # column tile width chosen by the autotuner
):
    # Two-pass dx kernel for rows too wide for one on-chip tile:
    #   dx = (dy*w - (x*inv_rms)/N * sum(x*inv_rms * dy*w)) * inv_rms
    # Programs stride over rows: program p handles rows p, p+P, p+2P, ...
    prog_num = tl.num_programs(0).to(tl.uint64)
    task_num = M
    pid = tl.program_id(0).to(tl.uint64)
    while pid < task_num:
        DX_ = DX + pid * dx_stride_r
        X_ = X + pid * x_stride_r
        # NOTE: DY is indexed with X's row stride; assumes dy has x's layout.
        DY_ = DY + pid * x_stride_r
        INV_RMS_ = INV_RMS + pid
        # inv_rms is a per-row scalar: load it once here and reuse it in both
        # passes (the original re-loaded it on every column tile).
        inv_rms = tl.load(INV_RMS_).to(tl.float32)

        # Pass 1: accumulate sum over columns of (x*inv_rms) * (dy*w).
        acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for m_idx in range(0, N, BLOCK_SIZE):
            cols = m_idx + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x = tl.load(X_ + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
            dy = tl.load(DY_ + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
            w = tl.load(W + cols, mask=mask, other=0.0)
            dy = dy * w
            normalized = x * inv_rms
            acc += normalized * dy

        row_sum_stats = tl.sum(acc, axis=0)

        # Pass 2: reload each tile and compute/store the gradient.
        for m_idx in range(0, N, BLOCK_SIZE):
            cols = m_idx + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x = tl.load(X_ + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
            dy = tl.load(DY_ + cols * x_stride_c, mask=mask, other=0.0).to(tl.float32)
            w = tl.load(W + cols, mask=mask, other=0.0)
            dy = dy * w
            normalized = x * inv_rms
            norm_val = normalized / N
            dx = (dy - norm_val * row_sum_stats) * inv_rms
            tl.store(DX_ + cols * dx_stride_c, dx, mask=mask)
        pid += prog_num
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms (precomputed by the forward pass)
    DW,  # pointer to the (row_block_num, N) partial-sum output buffer
    dx_stride_r,  # unused here; presumably kept for launch-signature symmetry
    dx_stride_c,  # unused here; presumably kept for launch-signature symmetry
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,  # rows per 2D tile
    COL_BLOCK_SIZE: tl.constexpr,  # columns per 2D tile
):
    # Weight-gradient partial sums on a 2D grid: program (r, c) reduces its
    # ROW_BLOCK_SIZE x COL_BLOCK_SIZE tile over rows and writes one row slice
    # of partials; the launcher finishes the reduction with torch.sum.
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    # Advance the base pointers to this tile's top-left element.
    offset = row_start * x_stride_r + col_start * x_stride_c
    X += offset
    DY += offset
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N

    # Loads are done in fp32; masked lanes read 0.0 and drop out of the sum.
    x = tl.load(
        X + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    dy = tl.load(
        DY + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    # d(weight) contribution: x_normalized * dy = x * inv_rms * dy.
    d_weight = x * dy * inv_rms[:, None]
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    # One partial row per grid-row; final cross-row sum happens on the host.
    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )
class RmsNorm(torch.autograd.Function):
    """Autograd bridge between PyTorch and the Cambricon RMSNorm kernels."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        # Run the fused forward, then stash everything backward will need.
        out, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return out

    @staticmethod
    def backward(ctx, dy):
        saved_x, saved_inv_rms, saved_weight = ctx.saved_tensors
        grad_x, grad_w = rms_norm_backward(
            dy, saved_x, saved_inv_rms, ctx.normalized_shape, saved_weight, ctx.eps
        )
        # Gradients line up with forward's inputs: only x and weight get one.
        return grad_x, None, grad_w, None
def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMS normalization over the trailing ``normalized_shape`` dims of ``x``."""
    return RmsNorm.apply(x, normalized_shape, weight, eps)