Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/skip_layernorm.py: 0%
93 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import builtins
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Fused skip-connection + LayerNorm, one row per program:
    #   y = ((x + r) - mean) * rsqrt(var + eps) * w + b
    # Single-tile variant: the whole row is covered by one
    # tl.arange(0, BLOCK_SIZE) tile, so the caller must guarantee
    # BLOCK_SIZE >= N (see the host-side dispatch on N).
    pid = tle.program_id(0)
    # Advance the row pointers to this program's row.
    Y += pid * y_stride_r
    X += pid * x_stride_r
    R += pid * r_stride_r

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    # Accumulate in fp32 for numerical stability; masked lanes read 0.0,
    # which keeps them neutral in the sums below.
    x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)

    x += r  # fused residual add

    mean = tl.sum(x, axis=0) / N

    # Compute variance; re-mask so out-of-row lanes (x=0 - mean) do not
    # pollute the reduction.
    _var = tl.where(mask, x - mean, 0.0)
    _var = _var * _var
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Affine parameters are 1-D of length N (contiguous; no stride args).
    w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0).to(tl.float32)
    b = tl.load(B + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0).to(tl.float32)

    x_hat = (x - mean) * rstd
    y = w * x_hat + b
    y = y.to(Y.dtype.element_ty)  # cast back from fp32 to the output dtype
    tl.store(Y + cols * y_stride_c, y, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel_tile(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N: tl.constexpr,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Tiled variant of skip_layer_norm_kernel for rows longer than one tile
    # (N > BLOCK_SIZE): each program handles one row and sweeps it three
    # times in BLOCK_SIZE chunks (mean pass, variance pass, normalize pass).
    # NOTE(review): the loads below index with `X + cols` / `R + cols`,
    # i.e. they ignore x_stride_c / r_stride_c and assume unit column
    # stride. The host wrapper always passes contiguous tensors with
    # stride 1, so behavior matches — confirm before reusing with other
    # strides.
    pid = tl.program_id(0)
    Y += pid * y_stride_r
    X += pid * x_stride_r
    R += pid * r_stride_r

    # Pass 1: per-lane partial sums of (x + r), reduced to the row mean.
    _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        _sum += x

    mean = tl.sum(_sum) / N

    # Pass 2: partial sums of squared deviations, reduced to the variance.
    # Masked lanes are zeroed so they stay neutral in the reduction.
    _var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        _var = tl.where(mask, x - mean, 0.0)
        _var = _var * _var
        _var_base += _var

    var = tl.sum(_var_base, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Pass 3: reload (x + r), normalize, apply the affine transform, store.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        b = tl.load(B + cols, mask, other=0.0).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = w * x_hat + b
        y = y.to(Y.dtype.element_ty)  # cast back from fp32 accumulation dtype
        tl.store(Y + cols * y_stride_c, y, mask=mask)
class SkipLayerNorm(torch.autograd.Function):
    """Fused skip-connection + LayerNorm: computes LayerNorm(x + residual).

    Forward-only autograd Function (no backward is defined); nothing is
    saved on ``ctx``.
    """

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        """Normalize ``x + residual`` over the trailing ``normalized_shape`` dims.

        Args:
            ctx: autograd context (unused; forward-only).
            x: input tensor whose trailing dims match ``normalized_shape``.
            residual: tensor added to ``x`` before normalization
                (same shape as ``x`` — TODO confirm with callers).
            normalized_shape: suffix of ``x.shape`` to normalize over.
            weight: scale parameter with ``prod(normalized_shape)`` elements.
            bias: shift parameter with ``prod(normalized_shape)`` elements.
            eps: value added to the variance for numerical stability.

        Returns:
            A tensor shaped like ``x``.
        """
        logger.debug("GEMS SKIP LAYERNORM FORWARD")
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])  # number of independent rows
        N = math.prod(normalized_shape)  # elements per row

        # Cap the tile size at the backend limit (core_num * buffer_size_limit
        # on kunlunxin); rows longer than that fall through to the tiled kernel.
        BLOCK_SIZE = builtins.min(64 * 64, triton.next_power_of_2(N))

        # Kernels index with unit column stride, so force contiguity.
        x = x.contiguous()
        residual = residual.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        y = torch.empty_like(x)

        # Fix: the original had an unreachable duplicated launch after this
        # `return` and used `torch.cuda.device` here; use the runtime's
        # device abstraction so the correct backend device is selected.
        with torch_device_fn.device(x.device):
            if N > 64 * 64:
                # Row does not fit in one tile: three-pass tiled kernel.
                skip_layer_norm_kernel_tile[M,](
                    y,
                    x,
                    residual,
                    weight,
                    bias,
                    N,
                    1,
                    N,
                    1,
                    N,
                    1,
                    N,
                    eps,
                    BLOCK_SIZE,
                    isCloseUnrollControl=True,  # kunlunxin-specific launch option
                )
            else:
                # Whole row fits in one tile: single-pass kernel.
                skip_layer_norm_kernel[M,](
                    y, x, residual, weight, bias, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
                )
        return y
def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Functional entry point for fused LayerNorm(x + residual).

    Thin convenience wrapper that forwards all arguments, in order, to
    ``SkipLayerNorm.apply``.
    """
    args = (x, residual, normalized_shape, weight, bias, eps)
    return SkipLayerNorm.apply(*args)