Coverage for src/flag_gems/runtime/backend/_ascend/fused/skip_layernorm.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # output pointer
    X,  # input pointer
    R,  # residual pointer
    W,  # weight pointer
    B,  # bias pointer
    y_stride_r,
    y_stride_c,
    x_stride_r,  # increment to advance one row of X
    x_stride_c,  # increment to advance one column of X
    r_stride_r,  # increment to advance one row of the residual
    r_stride_c,  # increment to advance one column of the residual
    N,  # number of columns in X
    eps,  # epsilon added to the variance to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Fused residual-add + LayerNorm, one row per program.

    For the row selected by the program id, computes
    y = W * normalize(X + R) + B, accumulating the mean/variance
    statistics in float32 regardless of the input dtype.
    """
    row = tle.program_id(0)
    num_blocks = tl.cdiv(N, BLOCK_SIZE)

    # Move the base pointers to this program's row.
    X += row * x_stride_r
    R += row * r_stride_r
    Y += row * y_stride_r

    # First traversal: accumulate sum(x) and sum(x^2) in a single pass so
    # both mean and variance can be derived afterwards via E[x^2] - E[x]^2.
    total = tl.zeros((), dtype=tl.float32)
    total_sq = tl.zeros((), dtype=tl.float32)
    for blk in range(num_blocks):
        offs = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_bounds = offs < N
        xv = tl.load(X + offs * x_stride_c, in_bounds, other=0.0).to(tl.float32)
        rv = tl.load(R + offs * r_stride_c, in_bounds, other=0.0).to(tl.float32)
        fused = xv + rv
        total += tl.sum(fused, axis=0)
        total_sq += tl.sum(fused * fused, axis=0)

    mean = total / N
    var = (total_sq / N) - mean * mean  # E[x^2] - (E[x])^2
    rstd = 1 / tl.sqrt(var + eps)  # reciprocal of the standard deviation

    # Second traversal: re-read x + residual (instead of caching the whole
    # row) and write y = w * (x - mean) * rstd + b.
    for blk in range(num_blocks):
        offs = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_bounds = offs < N
        w = tl.load(W + offs, mask=in_bounds, other=0.0).to(tl.float32)
        b = tl.load(B + offs, mask=in_bounds, other=0.0).to(tl.float32)
        xv = tl.load(X + offs * x_stride_c, in_bounds, other=0.0).to(tl.float32)
        rv = tl.load(R + offs * r_stride_c, in_bounds, other=0.0).to(tl.float32)
        fused = xv + rv
        normed = (fused - mean) * rstd
        out = (w * normed + b).to(Y.dtype.element_ty)
        tl.store(Y + offs * y_stride_c, out, mask=in_bounds)
class SkipLayerNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        """Compute layer_norm(x + residual) * weight + bias.

        The input is treated as (rows, N) with N = prod(normalized_shape);
        one kernel program is launched per row.
        """
        logger.debug("GEMS_ASCEND SKIP LAYERNORM FORWARD")
        batch_ndim = x.ndim - len(normalized_shape)
        # NOTE(review): the grid is capped at 65535 programs and the kernel
        # handles exactly one row per program — rows past the cap would be
        # left unwritten. Confirm callers never exceed 65535 rows.
        rows = min(math.prod(x.shape[:batch_ndim]), 65535)
        cols = math.prod(normalized_shape)
        block = min(triton.next_power_of_2(cols), 4096)
        x, residual = x.contiguous(), residual.contiguous()
        weight, bias = weight.contiguous(), bias.contiguous()
        out = torch.empty_like(x)
        with torch_device_fn.device(x.device):
            skip_layer_norm_kernel[rows,](
                out, x, residual, weight, bias,
                cols, 1,  # output row/col strides (contiguous)
                cols, 1,  # input row/col strides
                cols, 1,  # residual row/col strides
                cols, eps, block,
            )
        return out
def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Functional entry point for the fused residual-add + LayerNorm op."""
    args = (x, residual, normalized_shape, weight, bias, eps)
    return SkipLayerNorm.apply(*args)