Coverage for src/flag_gems/runtime/backend/_ascend/ops/rms_norm.py: 0%
125 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Namespaced logger: "flag_gems.runtime._ascend.ops.<module basename>".
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    Y,  # pointer to the output
    INV_RMS,  # pointer to inverse rms (one fp32 value per row, used by backward)
    X,  # pointer to the input
    W,  # pointer to the weights (indexed with unit stride)
    y_stride_r,  # how much to increase the Y pointer when moving by 1 row
    y_stride_c,  # how much to increase the Y pointer when moving by 1 col
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """RMSNorm forward: one program instance normalizes one row of X.

    Two passes over the row: first accumulate mean(x^2) in fp32, then write
    y = (x * rrms).to(out_dtype) * w and store rrms into INV_RMS[pid].
    """
    pid = tle.program_id(0)
    Y += pid * y_stride_r
    X += pid * x_stride_r

    # Pass 1: accumulate the mean of squares in fp32 for numerical stability.
    var = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # BUGFIX: honor x_stride_c instead of silently assuming a unit column
        # stride (callers currently pass 1, so behavior is unchanged for them).
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        var += tl.sum(x * x / N)

    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: normalize, apply the weight, and write the output row.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0)
        # Cast the normalized value back to the output dtype BEFORE applying w,
        # preserving the original rounding behavior.
        y = (x * rrms).to(Y.dtype.element_ty) * w
        tl.store(Y + cols * y_stride_c, y, mask=mask)

    tl.store(INV_RMS + pid, rrms)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the forward input
    DY,  # pointer to the upstream gradient (assumed same layout as X)
    INV_RMS,  # pointer to inverse rms saved by the forward kernel
    DX,  # pointer to the output gradient w.r.t. X
    W,  # pointer to the weights (indexed with unit stride)
    dx_stride_r,  # how much to increase the DX pointer when moving by 1 row
    dx_stride_c,  # how much to increase the DX pointer when moving by 1 col
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # unused here (rrms is precomputed); kept for interface stability
    BLOCK_SIZE: tl.constexpr,
):
    """RMSNorm backward w.r.t. the input: one program instance handles one row.

    dx = (dy*w - (x*rrms)/N * sum(x*rrms * dy*w)) * rrms, computed in fp32.
    """
    pid = tle.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    # NOTE: DY is advanced with X's row stride — DY must share X's layout.
    DY += pid * x_stride_r
    INV_RMS += pid

    inv_rms = tl.load(INV_RMS).to(tl.float32)

    # Pass 1: accumulate sum over the row of (x * rrms) * (dy * w).
    row_sum_stats = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # BUGFIX: honor x_stride_c instead of silently assuming a unit column
        # stride (callers currently pass 1, so behavior is unchanged for them).
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        row_sum_stats += tl.sum(normalized_buf * dy)

    # Pass 2: combine the per-element term with the row statistic and store dx.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        norm_val = normalized_buf / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms
        tl.store(DX + cols * dx_stride_c, dx, mask=mask)
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient (assumed same layout as X)
    INV_RMS,  # pointer to inverse rms, one value per row
    DW,  # pointer to the output: (num_row_blocks, N) partial dW sums
    dx_stride_r,  # unused in this kernel; kept to mirror the dx kernel's signature
    dx_stride_c,  # unused in this kernel; kept to mirror the dx kernel's signature
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Partial weight gradient: each program instance reduces one
    (ROW_BLOCK_SIZE x COL_BLOCK_SIZE) tile of sum_rows(x * dy * inv_rms) and
    writes a row-block partial into DW; the caller finishes the reduction
    with a torch.sum over dim 0.
    """
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    # Top-left corner of this program's tile.
    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    offset = row_start * x_stride_r + col_start * x_stride_c
    X += offset
    DY += offset
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    # Mask off rows/cols that fall outside the tensor at the grid edges.
    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N

    x = tl.load(
        X + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    # NOTE(review): DY is indexed with X's strides — assumes DY shares X's layout.
    dy = tl.load(
        DY + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    # dW contribution of this tile: reduce over the row axis only.
    d_weight = x * dy * inv_rms[:, None]
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    # DW is laid out as (num_row_blocks, N); each row block owns one row.
    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )
class RmsNorm(torch.autograd.Function):
    """RMS normalization over the trailing `normalized_shape` dims, with
    Triton forward and backward kernels for the Ascend backend."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        # BUGFIX: log message said "LAYERNORM" — this op is RMSNORM.
        logger.debug("GEMS_ASCEND RMSNORM FORWARD")
        # Flatten the leading dims into M rows; the trailing normalized_shape
        # dims form the N columns normalized together.
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        # 12064 caps the block size for this backend — TODO confirm against
        # hardware limits.
        BLOCK_SIZE = min(triton.next_power_of_2(N), 12064)

        # Kernels index with unit column stride, so force contiguous layouts.
        x = x.contiguous()
        weight = weight.contiguous()
        y = torch.empty_like(x)
        # Per-row 1/rms, kept in fp32 for the backward pass.
        inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

        with torch_device_fn.device(x.device):
            rms_norm_kernel[M,](y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)

        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        # BUGFIX: log message said "LAYERNORM" — this op is RMSNORM.
        logger.debug("GEMS_ASCEND RMSNORM BACKWARD")
        x, inv_rms, weight = ctx.saved_tensors
        normalized_shape = ctx.normalized_shape
        eps = ctx.eps

        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        BLOCK_SIZE = min(triton.next_power_of_2(N), 6912)
        x = x.contiguous()
        # BUGFIX: dy from autograd may be non-contiguous, but the kernels
        # advance DY with X's strides; force a matching contiguous layout.
        dy = dy.contiguous()
        weight = weight.contiguous()
        dx = torch.empty_like(x)

        with torch_device_fn.device(x.device):
            rms_norm_grad_dx_kernel[M,](
                x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
            )

        # dW is reduced in two stages: each (row-block, col-block) program
        # writes a partial row sum, then torch.sum folds the row blocks.
        ROW_BLOCK_SIZE = 16
        COL_BLOCK_SIZE = 256
        row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
        col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

        partial_buffer = torch.empty(
            (row_block_num, N), dtype=torch.float32, device=x.device
        )

        with torch_device_fn.device(x.device):
            rms_norm_grad_dw_kernel[row_block_num, col_block_num](
                x,
                dy,
                inv_rms,
                partial_buffer,
                N,
                1,
                N,
                1,
                M,
                N,
                ROW_BLOCK_SIZE,
                COL_BLOCK_SIZE,
            )
        dw = torch.sum(partial_buffer, dim=0, dtype=x.dtype).reshape(-1)

        # Gradients for (x, normalized_shape, weight, eps).
        return dx, None, dw, None
def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Functional entry point: RMS-normalize `x` over its trailing
    `normalized_shape` dims, scaled by `weight`, via the autograd Function."""
    result = RmsNorm.apply(x, normalized_shape, weight, eps)
    return result