Coverage for src/flag_gems/fused/skip_layernorm.py: 60%
50 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Fused skip-connection + LayerNorm: y = LayerNorm(x + r) * w + b.

    One program instance normalizes a single row; BLOCK_SIZE must be a
    power of two >= N so the whole row fits in one block.
    """
    row = tle.program_id(0)
    # Advance the row pointers to the row this program owns.
    Y += row * y_stride_r
    X += row * x_stride_r
    R += row * r_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N  # guard the tail when BLOCK_SIZE > N

    # Load in float32 for a numerically stable reduction, then add the skip.
    x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    x += tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)

    mean = tl.sum(x, axis=0) / N
    # Zero out masked lanes so they do not pollute the variance sum.
    centered = tl.where(mask, x - mean, 0.0)
    var = tl.sum(centered * centered, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    y = (w * ((x - mean) * rstd) + b).to(Y.dtype.element_ty)
    tl.store(Y + cols * y_stride_c, y, mask=mask)
class SkipLayerNorm(torch.autograd.Function):
    """Autograd entry point for the fused skip + layer-norm forward kernel."""

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        logger.debug("GEMS SKIP LAYERNORM FORWARD")
        # Collapse leading dims into M rows and the normalized dims into N cols.
        lead = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:lead])
        N = math.prod(normalized_shape)
        BLOCK_SIZE = triton.next_power_of_2(N)

        # The kernel is launched with row stride N / col stride 1, so every
        # tensor must be laid out contiguously.
        x = x.contiguous()
        residual = residual.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        y = torch.empty_like(x)

        with torch_device_fn.device(x.device):
            skip_layer_norm_kernel[(M,)](
                y, x, residual, weight, bias,
                N, 1,  # y strides (row, col)
                N, 1,  # x strides (row, col)
                N, 1,  # residual strides (row, col)
                N, eps, BLOCK_SIZE,
            )
        return y
def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Return LayerNorm(x + residual) over ``normalized_shape`` via the fused kernel."""
    return SkipLayerNorm.apply(x, residual, normalized_shape, weight, bias, eps)