Coverage for src/flag_gems/runtime/backend/_ascend/fla/layernorm_guard.py: 0%
96 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
2# Copyright (c) 2024, Tri Dao.
3# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
4# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
5# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
6# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
7# mypy: ignore-errors
9import torch
10import triton
11import triton.language as tl
13MAX_CORES = 65535
@triton.heuristics(
    {
        "HAS_BIAS": lambda args: args["B"] is not None,
        "HAS_Z": lambda args: args["Z"] is not None,
    }
)
@triton.jit
def layer_norm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other (gate) branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns (per group) in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    N_CORES: tl.constexpr,
):
    # Fused (gated) layer-norm / RMS-norm forward.
    # Each program handles one row slot of one group; rows are tiled over at
    # most N_CORES programs along axis 0, groups along axis 1.
    row = tl.program_id(0)
    group = tl.program_id(1)
    BLOCK_ROWS = M if M < N_CORES else N_CORES
    n_iters = M // BLOCK_ROWS
    remain = M % BLOCK_ROWS
    if row < remain:
        # The first `remain` programs each pick up one extra tail row.
        n_iters = n_iters + 1
    for i in tl.range(n_iters):
        X_base = X + (i * BLOCK_ROWS * stride_x_row) + row * stride_x_row + group * N
        Y_base = Y + (i * BLOCK_ROWS * stride_y_row) + row * stride_y_row + group * N
        if HAS_Z:
            Z_base = (
                Z + (i * BLOCK_ROWS * stride_z_row) + row * stride_z_row + group * N
            )
        # BUG FIX: Rstd_base used to be assigned only under `not IS_RMS_NORM`,
        # but `tl.store(Rstd_base + row, rstd)` below runs unconditionally, so
        # the name was undefined when IS_RMS_NORM was set. Compute it always;
        # only the Mean buffer is skipped for RMS norm.
        Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
        if not IS_RMS_NORM:
            Mean_base = Mean + (i * BLOCK_ROWS) + group * M
        W_base = W + group * N
        if HAS_BIAS:
            B_base = B + group * N
        # Compute mean and variance
        cols = tl.arange(0, BLOCK_N)
        x = tl.load(X_base + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_Z and not NORM_BEFORE_GATE:
            # Gate first, then normalize: x <- x * silu(z)
            z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
            x *= z * tl.sigmoid(z)
        if not IS_RMS_NORM:
            mean = tl.sum(x, axis=0) / N
            tl.store(Mean_base + row, mean)
            xbar = tl.where(cols < N, x - mean, 0.0)
            var = tl.sum(xbar * xbar, axis=0) / N
        else:
            xbar = tl.where(cols < N, x, 0.0)
            var = tl.sum(xbar * xbar, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        tl.store(Rstd_base + row, rstd)
        # Normalize and apply linear transformation
        mask = cols < N
        w = tl.load(W_base + cols, mask=mask).to(tl.float32)
        if HAS_BIAS:
            b = tl.load(B_base + cols, mask=mask).to(tl.float32)
        x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        y = x_hat * w + b if HAS_BIAS else x_hat * w
        if HAS_Z and NORM_BEFORE_GATE:
            # Normalize first, then gate: y <- y * silu(z)
            z = tl.load(Z_base + cols, mask=mask).to(tl.float32)
            y *= z * tl.sigmoid(z)
        # Write output
        tl.store(Y_base + cols, y, mask=mask)
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Launch the fused (gated) layer-norm forward kernel on a 2D input.

    Returns the output tensor together with the per-row, per-group mean
    buffer (None when is_rms_norm) and reciprocal-std buffer, both float32.
    """
    n_rows, n_cols = x.shape
    if group_size is None:
        group_size = n_cols
    assert n_cols % group_size == 0
    ngroups = n_cols // group_size
    # Every tensor must be contiguous along its last (feature) dimension.
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (n_rows, n_cols)
    assert weight.shape == (n_cols,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (n_cols,)
    # Allocate the output unless the caller supplied one of matching shape.
    if out is None:
        out = torch.empty_like(x)
    else:
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    mean = None
    if not is_rms_norm:
        mean = torch.empty((ngroups * n_rows,), dtype=torch.float32, device=x.device)
    rstd = torch.empty((ngroups * n_rows,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel.
    max_fused_size = 65536 // x.element_size()
    BLOCK_N = min(max_fused_size, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # Heuristic for the number of warps.
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (min(n_rows, MAX_CORES), ngroups)
    with torch.npu.device(x.device.index):
        layer_norm_fwd_kernel[grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z.stride(0) if z is not None else 0,
            n_rows,
            group_size,
            eps,
            BLOCK_N=BLOCK_N,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            N_CORES=MAX_CORES,
            num_warps=num_warps,
        )
    return out, mean, rstd
166class LayerNormFn(torch.autograd.Function):
167 @staticmethod
168 def forward(
169 ctx,
170 x,
171 weight,
172 bias,
173 z=None,
174 eps=1e-6,
175 group_size=None,
176 norm_before_gate=True,
177 is_rms_norm=False,
178 ):
179 """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
181 x_shape_og = x.shape
182 # reshape input data into 2D tensor
183 x = x.reshape(-1, x.shape[-1])
184 if x.stride(-1) != 1:
185 x = x.contiguous()
186 if z is not None:
187 assert z.shape == x_shape_og
188 z = z.reshape(-1, z.shape[-1])
189 if z.stride(-1) != 1:
190 z = z.contiguous()
191 weight = weight.contiguous()
192 if bias is not None:
193 bias = bias.contiguous()
194 y, mean, rstd = _layer_norm_fwd(
195 x,
196 weight,
197 bias,
198 eps,
199 z=z,
200 group_size=group_size,
201 norm_before_gate=norm_before_gate,
202 is_rms_norm=is_rms_norm,
203 )
204 return y.reshape(x_shape_og)