Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/skip_layernorm.py: 0%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import builtins 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module logger nested under the package-wide "flag_gems" logger so backend
# logging inherits the package's handlers/level configuration.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,  # output stride when moving by 1 row
    y_stride_c,  # output stride when moving by 1 col
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,  # tile width; caller picks next_power_of_2(N), so BLOCK_SIZE >= N
):
    # Single-tile skip-layernorm: each program normalizes one row of
    # (X + R) in float32 and writes w * x_hat + b into Y. The whole row
    # must fit in one BLOCK_SIZE-wide tile (see caller's BLOCK_SIZE choice).
    pid = tle.program_id(0)
    # Advance the row pointers to this program's row.
    Y += pid * y_stride_r
    X += pid * x_stride_r
    R += pid * r_stride_r

    # Lanes >= N are masked; masked loads yield 0.0 so the reductions stay correct.
    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)

    # Fused skip connection: normalization statistics are taken over x + residual.
    x += r

    mean = tl.sum(x, axis=0) / N

    # Compute variance
    _var = tl.where(mask, x - mean, 0.0)
    _var = _var * _var
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # W and B are indexed without a stride argument — assumes the affine
    # parameters are 1-D and contiguous (caller calls .contiguous() on them).
    w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0).to(tl.float32)
    b = tl.load(B + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0).to(tl.float32)

    x_hat = (x - mean) * rstd
    y = w * x_hat + b
    # Cast back from the float32 accumulator to the output dtype before storing.
    y = y.to(Y.dtype.element_ty)
    tl.store(Y + cols * y_stride_c, y, mask=mask)

61 

62 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel_tile(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,  # output stride when moving by 1 row
    y_stride_c,  # output stride when moving by 1 col
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N: tl.constexpr,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,  # tile width; rows longer than this are processed in chunks
):
    # Tiled skip-layernorm for rows that do not fit in a single tile: each
    # program normalizes one row of (X + R) in three passes over
    # BLOCK_SIZE-wide chunks (row sum -> variance -> normalize/store).
    pid = tl.program_id(0)
    # Advance the row pointers to this program's row.
    Y += pid * y_stride_r
    X += pid * x_stride_r
    R += pid * r_stride_r

    # Pass 1: accumulate x + residual per lane, then reduce to the row mean.
    _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Fix: honor the column strides. The original loads used raw `cols`,
        # silently assuming x_stride_c == r_stride_c == 1 (true for the
        # current caller, which passes contiguous tensors and stride 1, but
        # inconsistent with skip_layer_norm_kernel and the parameter list).
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        _sum += x

    mean = tl.sum(_sum) / N

    # Pass 2: accumulate squared deviations (masked lanes contribute 0).
    _var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        _var = tl.where(mask, x - mean, 0.0)
        _var_base += _var * _var

    var = tl.sum(_var_base, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Pass 3: reload each tile, normalize, apply the affine transform, store.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        # W and B are 1-D contiguous (caller calls .contiguous()), so raw
        # `cols` indexing is correct here.
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        b = tl.load(B + cols, mask, other=0.0).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = w * x_hat + b
        # Cast back from the float32 accumulator to the output dtype.
        y = y.to(Y.dtype.element_ty)
        tl.store(Y + cols * y_stride_c, y, mask=mask)

134 

135 

class SkipLayerNorm(torch.autograd.Function):
    """Fused skip connection + LayerNorm: y = LayerNorm(x + residual) * weight + bias.

    Forward-only (no backward is defined). Dispatches to one of two Triton
    kernels depending on whether a normalized row fits in a single tile.
    """

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        """Compute LayerNorm(x + residual) over the trailing `normalized_shape` dims.

        Args:
            ctx: autograd context (unused — nothing is saved for backward).
            x: input tensor whose trailing dims match `normalized_shape`.
            residual: tensor added to `x` before normalization
                (presumably the same shape as `x` — TODO confirm no broadcasting).
            normalized_shape: trailing dimensions to normalize over.
            weight: scale parameters, `normalized_shape` elements.
            bias: shift parameters, `normalized_shape` elements.
            eps: small constant added to the variance for numerical stability.

        Returns:
            A new contiguous tensor with the same shape and dtype as `x`.
        """
        logger.debug("GEMS SKIP LAYERNORM FORWARD")
        # Flatten leading dims into M rows of N normalized elements each.
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)

        # 64 * 64 == core_num * buffer_size_limit: the largest single tile;
        # longer rows fall back to the tiled three-pass kernel below.
        BLOCK_SIZE = builtins.min(64 * 64, triton.next_power_of_2(N))
        # Kernels index rows/cols with stride N/1, so force contiguity.
        x = x.contiguous()
        residual = residual.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        y = torch.empty_like(x)

        # Fix: removed an unreachable duplicate implementation that followed
        # the return statement (a second BLOCK_SIZE computation and kernel
        # launch could never execute).
        # NOTE(review): that dead code used torch_device_fn.device(...) while
        # the live path below uses torch.cuda.device(...); live behavior is
        # kept — confirm which context manager this backend actually needs.
        with torch.cuda.device(x.device):
            if N > 64 * 64:
                # Row does not fit in one tile: three-pass tiled kernel.
                skip_layer_norm_kernel_tile[M,](
                    y,
                    x,
                    residual,
                    weight,
                    bias,
                    N,  # y_stride_r
                    1,  # y_stride_c
                    N,  # x_stride_r
                    1,  # x_stride_c
                    N,  # r_stride_r
                    1,  # r_stride_c
                    N,
                    eps,
                    BLOCK_SIZE,
                    isCloseUnrollControl=True,
                )
            else:
                # Whole row fits in a single BLOCK_SIZE-wide tile.
                skip_layer_norm_kernel[M,](
                    y, x, residual, weight, bias, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
                )
        return y

204 

205 

def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Functional entry point: return LayerNorm(x + residual) * weight + bias
    via the fused Triton implementation (forward only)."""
    return SkipLayerNorm.apply(x, residual, normalized_shape, weight, bias, eps)