Coverage for src/flag_gems/runtime/backend/_cambricon/fused/skip_layernorm.py: 0%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

11from ..utils import TOTAL_CORE_NUM 

12 

13logger = logging.getLogger(__name__) 

14# When the reduced dimension is greater than MAX_C_MLU_SKIP_LAYERNORM_FORWARD, 

15# it is necessary to split the reduced dimension. 

16MAX_C_MLU_SKIP_LAYERNORM_FORWARD = 8192 

17 

18 

def cfggen_middle_n():
    """Enumerate autotune configurations for the middle-N kernel.

    Produces one ``triton.Config`` per combination of row-block size,
    warp count, and pipeline stage count. Only ``BLOCK_ROW_SIZE`` is a
    tunable meta-parameter; the column dimension is handled whole by
    the kernel.
    """
    configs = []
    for row_block in (1, 2, 4, 6, 8, 10):
        for warp_count in (1,):
            for stage_count in (1, 3):
                configs.append(
                    triton.Config(
                        {"BLOCK_ROW_SIZE": row_block},
                        num_warps=warp_count,
                        num_stages=stage_count,
                    )
                )
    return configs

37 

38 

@libentry()
@triton.autotune(configs=cfggen_middle_n(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_middle_n_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    M,  # number of rows in X
    eps,  # epsilon to avoid division by zero
    N: tl.constexpr,  # number of columns in X
    BLOCK_ROW_SIZE: tl.constexpr,
):
    # Fused (X + R) -> LayerNorm for moderate N: each program processes whole
    # rows (all N columns at once) and strides over rows so that a grid of
    # at most TOTAL_CORE_NUM programs covers all M rows.
    # NOTE(review): tl.arange(0, N) requires N to be a compile-time power of
    # two in mainline Triton — presumably the Cambricon fork relaxes this or
    # callers guarantee it; confirm.
    pid = tl.program_id(0)
    row_start = pid * BLOCK_ROW_SIZE
    num_jobs = tl.num_programs(axis=0)
    # Stride between successive row-blocks handled by this program.
    step = num_jobs * BLOCK_ROW_SIZE

    # Pre-offset the data pointers by the column indices once; per-row
    # offsets are added inside the loop.
    cols_n = tl.arange(0, N)
    X += cols_n[None, :]
    R += cols_n[None, :]
    Y += cols_n[None, :]
    cols_off = tl.arange(0, N)[None, :]
    # Weight and bias are row-invariant, so load them once before the loop.
    w = tl.load(W + cols_off)
    b = tl.load(B + cols_off)
    for row in range(row_start, M, step):
        row_off = row + tl.arange(0, BLOCK_ROW_SIZE)
        mask = row_off[:, None] < M
        off = row_off[:, None] * N
        x = tl.load(X + off, mask, other=0.0).to(tl.float32)
        r = tl.load(R + off, mask, other=0.0).to(tl.float32)
        # Fused residual add; statistics below are computed over x + r.
        x += r

        # TODO: Use the following code as a fallback once the optimization for trans is complete.
        # mean = tl.sum(x_v, axis=1) / N
        # var = tl.sum(x_v * x_v, axis=1) / N - (mean * mean)
        # mean_bc = mean[:, None]

        # Reduce along axis 0 of the transposed tile instead of axis 1 of the
        # original — a backend-specific layout optimization (see TODO above).
        x_v = tl.view(x, (BLOCK_ROW_SIZE, N))
        x_trans = tl.trans(x_v)
        mean = tl.sum(x_trans, axis=0) / N
        mean_bc = mean[:, None]
        # Variance via E[x^2] - E[x]^2.
        var = tl.sum(x_trans * x_trans, axis=0) / N - (mean * mean)
        var = var[:, None]
        rstd = 1 / tl.sqrt(var + eps)
        x = x - mean_bc
        x_hat = x * rstd
        # Affine transform; w/b keep their storage dtype, so y is promoted
        # accordingly before the (masked) store.
        y = x_hat * w + b
        tl.store(Y + off, y, mask=mask)

89 

90 

def cfggen():
    """Enumerate autotune configurations for the split-column kernel.

    Produces one ``triton.Config`` per combination of row-block size,
    column-block size, warp count, and pipeline stage count.
    """
    configs = []
    for row_block in range(1, 36, 4):
        for col_block in range(64, 193, 64):
            for warp_count in (1,):
                for stage_count in (1, 3):
                    configs.append(
                        triton.Config(
                            {
                                "BLOCK_ROW_SIZE": row_block,
                                "BLOCK_COL_SIZE": col_block,
                            },
                            num_warps=warp_count,
                            num_stages=stage_count,
                        )
                    )
    return configs

106 

107 

@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    M,  # number of rows in X
    eps,  # epsilon to avoid division by zero
    N: tl.constexpr,  # number of columns in X
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Fused (X + R) -> LayerNorm for large N: the row is processed in
    # BLOCK_COL_SIZE chunks, in two passes — first to accumulate the
    # statistics, then to normalize and apply the affine transform.
    pid = tl.program_id(0)
    row = pid * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = row < M
    # Shift the data pointers to the start of this program's row block.
    Y += row * N
    X += row * N
    R += row * N

    # Compute mean
    _mean = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Compute variance
    _var = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Pass 1: accumulate per-lane partial sums of x and x^2 over the row.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = cols < N
        # `and` is rewritten to an element-wise logical AND by the Triton
        # JIT's AST transform; this is not Python short-circuit evaluation.
        mask = row_mask and col_mask

        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        _mean += x
        _var += x * x
    # Reduce the partial sums along axis 0 of the transposed accumulator —
    # same backend-specific trans optimization as the middle-N kernel.
    trans_mean = tl.trans(_mean)
    mean = tl.sum(trans_mean, axis=0) / N
    mean_bc = mean[:, None]
    trans_var = tl.trans(_var)
    # Variance via E[x^2] - E[x]^2.
    var = tl.sum(trans_var, axis=0) / N - (mean * mean)
    var = var[:, None]
    rstd = 1 / tl.sqrt(var + eps)

    # Normalize and apply linear transformation
    # Pass 2: re-load x and r (they were not kept resident), normalize, and
    # apply the per-column weight/bias.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        w = tl.load(W + cols, col_mask)
        b = tl.load(B + cols, col_mask)
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
        x += r
        # Zero out-of-range lanes so the masked store below writes clean data.
        x = tl.where(col_mask, x - mean_bc, 0.0)
        x_hat = x * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)

168 

169 

class SkipLayerNorm(torch.autograd.Function):
    """Fused skip (residual-add) + LayerNorm forward for the Cambricon backend.

    Computes ``layer_norm(x + residual, normalized_shape, weight, bias, eps)``
    in a single kernel launch, dispatching to one of two Triton kernels based
    on the size of the flattened reduced dimension N.
    """

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        """Run the fused forward pass and return the normalized tensor.

        Args:
            ctx: autograd context (unused; no backward is saved here).
            x: input tensor.
            residual: tensor added to ``x`` before normalization; must be
                broadcast-compatible with ``x`` (same shape in practice,
                since both are flattened to (M, N)).
            normalized_shape: trailing dims of ``x`` to normalize over.
            weight, bias: affine parameters of shape ``normalized_shape``.
            eps: numerical-stability epsilon added to the variance.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        logger.debug("GEMS_CAMBRICON SKIP LAYERNORM FORWARD")
        # M = product of leading dims, N = product of the normalized dims;
        # the kernels view the data as a contiguous (M, N) matrix.
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        x = x.contiguous()
        residual = residual.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        y = torch.empty_like(x)

        if N < MAX_C_MLU_SKIP_LAYERNORM_FORWARD:
            # Whole-row kernel: each program loads full rows, so cap the grid
            # at TOTAL_CORE_NUM and let the kernel loop over remaining rows.
            grid = lambda META: (
                min(triton.cdiv(M, META["BLOCK_ROW_SIZE"]), TOTAL_CORE_NUM),
            )
            # Fix: use the backend-agnostic device guard here too; the
            # original used torch.cuda.device, which is the wrong context
            # manager on Cambricon MLU devices and inconsistent with the
            # else-branch below.
            with torch_device_fn.device(x.device):
                skip_layer_norm_middle_n_kernel[grid](
                    y, x, residual, weight, bias, M, eps, N
                )
        else:
            # Split-column kernel: one program per row-block, N chunked
            # inside the kernel.
            grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
            with torch_device_fn.device(x.device):
                skip_layer_norm_kernel[grid](y, x, residual, weight, bias, M, eps, N)
        return y

198 

199 

def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Fused residual-add + layer norm: layer_norm(x + residual, ...)."""
    result = SkipLayerNorm.apply(x, residual, normalized_shape, weight, bias, eps)
    return result