Coverage for src/flag_gems/ops/layernorm.py: 32%
241 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
@triton.jit
def prev_multiple_of(a, b):
    # Largest x < a such that x % b == 0 (assumes a > 0, b > 0).
    # ceil(a / b) rounds up to the covering tile count; stepping one tile
    # back lands on the last multiple of b strictly below a.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,  # optional affine weight; None => identity scale
    bias_ptr,  # optional affine bias; None => zero shift
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # One program normalizes one row of the (M, N) input; the whole row is
    # held in a single 1D tile (the launcher passes
    # TILE_N = next_power_of_2(N), so TILE_N >= N and no inner loop is needed).
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    # Statistics are accumulated in float32 regardless of the input dtype.
    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.sum(x) / N
    d = x - m  # deviation
    # Zero out masked lanes so they do not contribute to the variance sum.
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s)  # sum of square of deviation
    var = sum_square / N  # biased (divide-by-N) variance
    rstd = tl.math.rsqrt(var + eps)

    # Save per-row statistics for the backward pass.
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    # Null weight/bias pointers degrade to the identity affine transform.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + n_offsets, mask=mask)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + n_offsets, mask=mask)
    out = (x - m) * rstd * w + b

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,  # optional affine weight; None => identity scale
    bias_ptr,  # optional affine bias; None => zero shift
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
):
    # Variant for short rows (small N): each program processes TILE_M rows
    # at once as a 2D (TILE_M, TILE_N) tile so the hardware stays busy.
    # The launcher guarantees TILE_N >= N (next power of two of N).
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask

    # Accumulate statistics in float32 regardless of the input dtype.
    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    m = tl.sum(x, axis=1) / N  # per-row mean, shape (TILE_M,)
    d = x - m[:, None]  # deviation
    # Zero masked lanes so padding does not pollute the variance sum.
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N  # biased (divide-by-N) variance
    rstd = tl.math.rsqrt(var + eps)

    # Save per-row statistics for the backward pass.
    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    # Null weight/bias pointers degrade to the identity affine transform;
    # loaded values broadcast over the row dimension.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + n_offsets, mask=n_mask)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + n_offsets, mask=n_mask)
    out = (x - m[:, None]) * rstd[:, None] * w + b

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_loop"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,  # optional affine weight; None => identity scale
    bias_ptr,  # optional affine bias; None => zero shift
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # Variant for long rows (N too large for one tile). One program handles
    # one row, sweeping it in TILE_N chunks:
    #   pass 1 - per-lane Welford recurrence, then a cross-lane combine;
    #   pass 2 - re-read the row (tail first) to normalize it.
    # Map the program id to the row of X and Y it should compute.
    pid = tle.program_id(0)

    # Compute mean
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # mean
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2)
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_steps = tl.cdiv(N, TILE_N)
    # All full tiles: each lane absorbs one sample per step, so the per-lane
    # sample count after iteration `step` is step + 1 (the Welford divisor).
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last step
    # Tail tile, possibly partial: out-of-range lanes keep their previous
    # m/s and do not bump cnt, so the combine below stays correct.
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Combine the TILE_N per-lane (mean, M2, count) triples into row-level
    # statistics (Chan et al. parallel-variance merge).
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m
    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    # reverse the order of the second sweep
    # Normalize and apply linear transformation
    # Walking backwards touches the tail tile first (most likely still in
    # cache after pass 1); evict_first hints the data is not reused.
    prev_multiple = prev_multiple_of(N, TILE_N)
    # the first step, masking is needed
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + n_offsets, mask=mask)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + n_offsets, mask=mask)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # Remaining tiles are guaranteed full, so no masking is needed.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + n_offsets)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + n_offsets)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_backward"),
    key=["M", "N"],
)
@triton.jit
def layer_norm_backward_kernel(
    dY,  # upstream gradient, (M, N)
    X,  # forward input, (M, N)
    W,  # optional weight; None => treated as ones
    Mean,  # per-row mean saved by the forward pass
    Rstd,  # per-row 1/std saved by the forward pass
    dX,  # output: gradient w.r.t. X
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Input gradient:
    #   dx = rstd * (dy*w - (sum(dy*w) + x_hat * sum(dy*w*x_hat)) / N)
    # Each program owns BLOCK_ROW_SIZE rows; two column sweeps: first to
    # accumulate the two row-wise sums, second to emit dx.
    pid = tle.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = pid < M
    # Advance the base pointers to this program's rows.
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    mean = tl.load(Mean, mask=row_mask).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask).to(tl.float32)

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # Sweep 1: accumulate sum(dy*w) and sum(dy*w*x_hat) per row.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        # NOTE(review): relies on Triton's compiler lowering `and` between
        # tensor masks to an elementwise logical-and — confirm against the
        # pinned Triton version.
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        # Zero masked lanes so they do not contribute to the reductions.
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    # Reduce over columns; keep a trailing axis for broadcasting below.
    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    # Sweep 2: recompute x_hat per tile and emit the input gradient.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols, dx, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    key=["N"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,  # upstream gradient, (M, N)
    X,  # forward input, (M, N)
    Mean,  # per-row mean saved by the forward pass
    Rstd,  # per-row 1/std saved by the forward pass
    dW,  # output: gradient w.r.t. weight, or null to skip
    dB,  # output: gradient w.r.t. bias, or null to skip
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Weight/bias gradients are column-wise reductions over all M rows:
    #   dW[j] = sum_i dy[i,j] * x_hat[i,j],  dB[j] = sum_i dy[i,j].
    # Each program owns BLOCK_COL_SIZE columns and loops over row blocks.
    pid = tle.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)
    col_mask = pid < N
    dY += pid[None, :]
    X += pid[None, :]
    accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_ROW_SIZE):
        rows = off + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
        row_mask = rows < M
        # NOTE(review): relies on Triton lowering `and` between tensor masks
        # to an elementwise logical-and — confirm against the pinned version.
        mask = row_mask and col_mask[None, :]
        dy = tl.load(dY + rows * N, mask).to(tl.float32)
        x = tl.load(X + rows * N, mask).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M).to(tl.float32)
        rstd = tl.load(Rstd + rows, mask=rows < M).to(tl.float32)
        # Zero masked lanes so padding rows do not pollute the accumulators.
        x = tl.where(mask, x - mean, 0.0)
        accW += dy * x * rstd  # dy * x_hat
        accB += dy
    # Null output pointers mean the corresponding gradient was not requested.
    if dW:
        dw = tl.sum(accW, axis=0)
        tl.store(dW + pid, dw, mask=col_mask)
    if dB:
        db = tl.sum(accB, axis=0)
        tl.store(dB + pid, db, mask=col_mask)
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """LayerNorm forward.

    Normalizes the trailing ``normalized_shape`` dimensions of ``input``
    (biased variance, as in ``torch.nn.LayerNorm``) and applies the optional
    affine transform ``weight``/``bias``.

    Args:
        input: tensor whose trailing dims match ``normalized_shape``.
        normalized_shape: shape of the dims to normalize over.
        weight: optional scale of shape ``normalized_shape``.
        bias: optional shift of shape ``normalized_shape``.
        eps: value added to the variance for numerical stability.

    Returns:
        Tuple ``(y, mean, rstd)`` where ``mean``/``rstd`` are the per-row
        statistics saved for backward (length ``M`` = number of rows).
    """
    logger.debug("GEMS LAYERNORM FORWARD")

    # Flatten to an (M, N) view: N elements per normalized row.
    N = math.prod(normalized_shape)
    M = input.numel() // N

    input = input.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    y = torch.empty_like(input)

    # NOTE: when the input is half-precision (either float16 or bfloat16)
    # the statistics saved for backward are kept in single precision,
    # matching torch's native_layer_norm accumulation type. (The previous
    # code allocated them with input.dtype, contradicting this intent and
    # losing precision in the backward pass.)
    acc_dtype = (
        torch.float32
        if input.dtype in (torch.float16, torch.bfloat16)
        else input.dtype
    )
    mean = torch.empty(M, dtype=acc_dtype, device=input.device)
    rstd = torch.empty(M, dtype=acc_dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if N <= 128:
            # Short rows: pack several rows per program (2D tile).
            TILE_N = triton.next_power_of_2(N)
            TILE_M = triton.cdiv(1024, TILE_N)
            grid = (triton.cdiv(M, TILE_M), 1, 1)
            layer_norm_persistent_kernel_multiline[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_M,
                TILE_N,
            )
        elif N <= 4096:
            # Medium rows: one program per row, whole row in one tile.
            TILE_N = triton.next_power_of_2(N)
            grid = (M, 1, 1)
            layer_norm_persistent_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_N,
            )
        else:
            # Long rows: one program per row, looping over column tiles.
            grid = (M, 1, 1)
            layer_norm_loop_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
            )
    return y, mean, rstd
def layer_norm_backward(
    grad_out,
    input,
    normalized_shape,
    mean,
    rstd,
    weight=None,
    bias=None,
    output_mask=None,
):
    """LayerNorm backward.

    Args:
        grad_out: upstream gradient, same shape as ``input``.
        input: the forward input tensor.
        normalized_shape: shape of the normalized trailing dims.
        mean: per-row mean saved by the forward pass.
        rstd: per-row 1/std saved by the forward pass.
        weight: optional affine weight used in the forward pass.
        bias: optional affine bias used in the forward pass.
        output_mask: 3-tuple of bools selecting which of
            (input grad, weight grad, bias grad) to compute. ``None``
            computes the input grad plus whichever affine params exist.

    Returns:
        Tuple ``(in_grad, weight_grad, bias_grad)``; unrequested entries
        are ``None``.
    """
    logger.debug("GEMS LAYERNORM BACKWARD")

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    # Derive (M, N) from normalized_shape, mirroring the forward pass.
    # The previous `M = input.shape[0]` was wrong whenever the input had
    # more than one leading (non-normalized) dimension.
    N = math.prod(normalized_shape)
    M = input.numel() // N

    if output_mask is None:
        # Previously `None` crashed on the first subscript; default to every
        # gradient that has a corresponding tensor.
        output_mask = (True, weight is not None, bias is not None)

    if output_mask[0]:
        in_grad = torch.empty_like(input)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
        with torch_device_fn.device(input.device):
            layer_norm_backward_kernel[grid](
                grad_out, input, weight, mean, rstd, in_grad, M, N
            )
    else:
        in_grad = None

    # Nothing left to do if neither affine gradient was requested.
    # (Truthiness test instead of the fragile `is False` identity check.)
    if not (output_mask[1] or output_mask[2]):
        return in_grad, None, None

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
    weight_grad = torch.empty_like(weight) if output_mask[1] else None
    bias_grad = torch.empty_like(bias) if output_mask[2] else None
    with torch_device_fn.device(input.device):
        # A single kernel accumulates both dW and dB; null pointers skip one.
        weight_bias_backward_kernel[grid](
            grad_out, input, mean, rstd, weight_grad, bias_grad, M, N
        )
    return in_grad, weight_grad, bias_grad