Coverage for src/flag_gems/runtime/backend/_cambricon/fused/skip_layernorm.py: 0%
120 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
11from ..utils import TOTAL_CORE_NUM
# Module-level logger; emits debug messages from the forward entry point below.
logger = logging.getLogger(__name__)

# When the reduced dimension is greater than MAX_C_MLU_SKIP_LAYERNORM_FORWARD,
# it is necessary to split the reduced dimension.
# (Used at the launch site to pick between the whole-row kernel and the
# split-column kernel.)
MAX_C_MLU_SKIP_LAYERNORM_FORWARD = 8192
def cfggen_middle_n():
    """Build the autotune search space for the whole-row (middle-N) kernel.

    Sweeps BLOCK_ROW_SIZE over a fixed set of row-tile sizes, crossed with
    num_warps=1 and num_stages in {1, 3}.
    """
    configs = []
    for rows in (1, 2, 4, 6, 8, 10):
        for stages in (1, 3):
            configs.append(
                triton.Config(
                    {"BLOCK_ROW_SIZE": rows},
                    num_warps=1,
                    num_stages=stages,
                )
            )
    return configs
@libentry()
@triton.autotune(configs=cfggen_middle_n(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_middle_n_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    M,  # number of rows in X
    eps,  # epsilon to avoid division by zero
    N: tl.constexpr,  # number of columns in X
    BLOCK_ROW_SIZE: tl.constexpr,
):
    # Fused skip-layernorm: Y = layer_norm(X + R) * W + B, computed row by row.
    # This variant loads a full row of N columns at once, so it is launched only
    # when N < MAX_C_MLU_SKIP_LAYERNORM_FORWARD (see the launch site below).
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    num_jobs = tl.num_programs(axis=0)
    # Persistent-style loop: each program strides over rows by
    # (num_programs * BLOCK_ROW_SIZE); the grid is capped at TOTAL_CORE_NUM
    # by the caller.
    step = num_jobs * BLOCK_ROW_SIZE

    cols_n = tl.arange(0, N)
    # Pre-offset the base pointers by the column index once, outside the loop.
    X += cols_n[None, :]
    R += cols_n[None, :]
    Y += cols_n[None, :]
    cols_off = tl.arange(0, N)[None, :]
    # Weight/bias are shared by every row; load them once before the row loop.
    w = tl.load(W + cols_off)
    b = tl.load(B + cols_off)
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mask = row_off[:, None] < M
        off = row_off[:, None] * N
        x = tl.load(X + off, mask, other=0.0).to(tl.float32)
        r = tl.load(R + off, mask, other=0.0).to(tl.float32)
        x += r  # the "skip" connection: normalize (input + residual)

        # TODO: Use the following code as a fallback once the optimization for trans is complete.
        # mean = tl.sum(x_v, axis=1) / N
        # var = tl.sum(x_v * x_v, axis=1) / N - (mean * mean)
        # mean_bc = mean[:, None]

        # Reduce along axis 0 of the transposed tile instead of axis 1 of the
        # original — presumably axis-0 reduction is faster on this backend
        # until the trans optimization mentioned above lands (verify).
        x_v = tl.view(x, (BLOCK_ROW_SIZE, N))
        x_trans = tl.trans(x_v)
        mean = tl.sum(x_trans, axis=0) / N
        mean_bc = mean[:, None]
        # Single-pass variance via E[x^2] - E[x]^2.
        var = tl.sum(x_trans * x_trans, axis=0) / N - (mean * mean)
        var = var[:, None]
        rstd = 1 / tl.sqrt(var + eps)
        x = x - mean_bc
        x_hat = x * rstd
        y = x_hat * w + b
        tl.store(Y + off, y, mask=mask)
def cfggen():
    """Build the autotune search space for the split-column kernel.

    Sweeps BLOCK_ROW_SIZE over 1, 5, 9, ..., 33 and BLOCK_COL_SIZE over
    64, 128, 192, crossed with num_warps=1 and num_stages in {1, 3}.
    """
    # NOTE: the previous inline comment here claimed "[1, 2, 4]", which did
    # not match range(1, 36, 4); the actual values are 1, 5, 9, ..., 33.
    block_m = list(range(1, 36, 4))
    block_n = list(range(64, 193, 64))  # 64, 128, 192
    warps = [1]
    num_stages = [1, 3]
    configs = [
        triton.Config(
            {"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w, num_stages=s
        )
        for m in block_m
        for n in block_n
        for w in warps
        for s in num_stages
    ]
    return configs
@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    M,  # number of rows in X
    eps,  # epsilon to avoid division by zero
    N: tl.constexpr,  # number of columns in X
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Fused skip-layernorm for wide rows (launched when
    # N >= MAX_C_MLU_SKIP_LAYERNORM_FORWARD): the reduced dimension is walked
    # in BLOCK_COL_SIZE chunks, so X and R are read twice — once to accumulate
    # the statistics, once to normalize and write the output.
    pid = tl.program_id(0)
    row = pid * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = row < M
    # Pre-offset the base pointers to this program's block of rows.
    Y += row * N
    X += row * N
    R += row * N

    # Compute mean
    _mean = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Compute variance
    _var = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Pass 1: accumulate per-chunk sums of x and x*x, where x = input + residual.
    # Out-of-range lanes contribute 0.0 via the load mask, so they do not skew
    # the sums.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        _mean += x
        _var += x * x
    # Collapse the per-chunk accumulators along axis 0 of their transpose
    # (same transpose-reduction trick as skip_layer_norm_middle_n_kernel).
    trans_mean = tl.trans(_mean)
    mean = tl.sum(trans_mean, axis=0) / N
    mean_bc = mean[:, None]
    trans_var = tl.trans(_var)
    # Variance via E[x^2] - E[x]^2.
    var = tl.sum(trans_var, axis=0) / N - (mean * mean)
    var = var[:, None]
    rstd = 1 / tl.sqrt(var + eps)

    # Normalize and apply linear transformation
    # Pass 2: reload (input + residual) chunk by chunk and write the output.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        w = tl.load(W + cols, col_mask)
        b = tl.load(B + cols, col_mask)
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        # Zero the padded columns so the masked lanes stay finite before the
        # multiply; the store mask below keeps them from being written anyway.
        x = tl.where(col_mask, x - mean_bc, 0.0)
        x_hat = x * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)
class SkipLayerNorm(torch.autograd.Function):
    """Fused (x + residual) -> LayerNorm for the Cambricon backend.

    Forward-only: no backward is defined here, so gradients cannot flow
    through this op.
    """

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        """Compute layer_norm(x + residual) scaled by weight and shifted by bias.

        Args:
            ctx: autograd context (unused; nothing is saved for backward).
            x: input tensor; its trailing dims must match normalized_shape.
            residual: tensor added to x before normalization (same shape as x).
            normalized_shape: trailing dimensions to normalize over.
            weight, bias: affine parameters of size prod(normalized_shape).
            eps: numerical-stability term added to the variance.

        Returns:
            A new tensor of the same shape as x.
        """
        logger.debug("GEMS_CAMBRICON SKIP LAYERNORM FORWARD")
        # Flatten leading dims into M rows and the normalized dims into N cols.
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        # Kernels index with row * N + col, so inputs must be contiguous.
        x = x.contiguous()
        residual = residual.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        y = torch.empty_like(x)

        if N < MAX_C_MLU_SKIP_LAYERNORM_FORWARD:
            # Whole row fits in one tile: persistent kernel, grid capped at the
            # physical core count (each program strides over remaining rows).
            grid = lambda META: (
                min(triton.cdiv(M, META["BLOCK_ROW_SIZE"]), TOTAL_CORE_NUM),
            )
            # BUGFIX: use the backend-generic device context. The original used
            # torch.cuda.device here (but torch_device_fn.device in the other
            # branch), which is inconsistent on this non-CUDA MLU backend.
            with torch_device_fn.device(x.device):
                skip_layer_norm_middle_n_kernel[grid](
                    y, x, residual, weight, bias, M, eps, N
                )
        else:
            # Wide rows: split the reduced dimension into BLOCK_COL_SIZE chunks.
            grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
            with torch_device_fn.device(x.device):
                skip_layer_norm_kernel[grid](y, x, residual, weight, bias, M, eps, N)
        return y
def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Functional entry point: layer_norm(x + residual) with affine weight/bias."""
    args = (x, residual, normalized_shape, weight, bias, eps)
    return SkipLayerNorm.apply(*args)