Coverage for src/flag_gems/runtime/backend/_ascend/fla/layernorm_guard.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py 

2# Copyright (c) 2024, Tri Dao. 

3# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html 

4# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. 

5# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. 

6# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 

7# mypy: ignore-errors 

8 

9import torch 

10import triton 

11import triton.language as tl 

12 

# Upper bound on the number of programs launched along grid dimension 0.
# When the row count M exceeds this, each program loops over multiple rows
# inside the kernel (passed to the kernel as N_CORES).
MAX_CORES = 65535

14 

15 

@triton.heuristics(
    {
        # Compile-time flags derived from whether the optional pointers were
        # actually passed by the launcher.
        "HAS_BIAS": lambda args: args["B"] is not None,
        "HAS_Z": lambda args: args["Z"] is not None,
    }
)
@triton.jit
def layer_norm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch (gate input; y is multiplied by silu(z))
    Mean,  # pointer to the mean (only written when not IS_RMS_NORM)
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X_base
    N,  # number of columns in X_base (i.e. the group size)
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    N_CORES: tl.constexpr,
):
    # Map the program id to the row of X_base and Y_base it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)

    # The grid is capped at N_CORES programs along dim 0, so each program may
    # handle several rows, striding by BLOCK_ROWS between iterations.
    BLOCK_ROWS = M if M < N_CORES else N_CORES
    n_iters = M // BLOCK_ROWS
    remain = M % BLOCK_ROWS
    if row < remain:
        # The first `remain` programs each pick up one leftover row.
        n_iters = n_iters + 1

    for i in tl.range(n_iters):
        # Per-iteration base pointers; each group owns a contiguous slice of
        # N columns within the row, and mean/rstd are laid out per group with
        # stride M (matching the (ngroups * M,) buffers in _layer_norm_fwd).
        X_base = X + (i * BLOCK_ROWS * stride_x_row) + row * stride_x_row + group * N
        Y_base = Y + (i * BLOCK_ROWS * stride_y_row) + row * stride_y_row + group * N
        if HAS_Z:
            Z_base = (
                Z + (i * BLOCK_ROWS * stride_z_row) + row * stride_z_row + group * N
            )
        if not IS_RMS_NORM:
            Mean_base = Mean + (i * BLOCK_ROWS) + group * M
        # Rstd is stored unconditionally below, so its base pointer is set up
        # for both the layer-norm and RMS-norm paths.
        Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
        W_base = W + group * N
        if HAS_BIAS:
            B_base = B + group * N
        # Compute mean and variance
        cols = tl.arange(0, BLOCK_N)
        x = tl.load(X_base + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_Z and not NORM_BEFORE_GATE:
            # Gate first, then normalize: stats are taken over x * silu(z).
            z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
            x *= z * tl.sigmoid(z)
        if not IS_RMS_NORM:
            mean = tl.sum(x, axis=0) / N
            tl.store(Mean_base + row, mean)
            # Mask out-of-range columns so they don't pollute the variance.
            xbar = tl.where(cols < N, x - mean, 0.0)
            var = tl.sum(xbar * xbar, axis=0) / N
        else:
            xbar = tl.where(cols < N, x, 0.0)
            var = tl.sum(xbar * xbar, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        tl.store(Rstd_base + row, rstd)
        # Normalize and apply linear transformation
        mask = cols < N
        w = tl.load(W_base + cols, mask=mask).to(tl.float32)
        if HAS_BIAS:
            b = tl.load(B_base + cols, mask=mask).to(tl.float32)
        x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        y = x_hat * w + b if HAS_BIAS else x_hat * w
        if HAS_Z and NORM_BEFORE_GATE:
            # Normalize first, then gate: y = norm(x) * silu(z).
            z = tl.load(Z_base + cols, mask=mask).to(tl.float32)
            y *= z * tl.sigmoid(z)
        # Write output
        tl.store(Y_base + cols, y, mask=mask)

95 

96 

def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Launch the fused (optionally gated) layer-norm forward kernel.

    x is a 2D (rows, features) tensor; features may be split into groups of
    ``group_size`` columns that are normalized independently.  Returns
    ``(out, mean, rstd)``; ``mean`` is None when ``is_rms_norm`` is set.
    """
    num_rows, num_cols = x.shape
    if group_size is None:
        group_size = num_cols
    assert num_cols % group_size == 0
    n_groups = num_cols // group_size
    # The kernel indexes columns directly, so the last dim must be contiguous.
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (num_rows, num_cols)
    assert weight.shape == (num_cols,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (num_cols,)
    # Allocate (or validate a caller-provided) output buffer.
    if out is None:
        out = torch.empty_like(x)
    else:
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    # Per-(group, row) statistics, flattened as group-major 1D buffers.
    if is_rms_norm:
        mean = None
    else:
        mean = torch.empty((n_groups * num_rows,), dtype=torch.float32, device=x.device)
    rstd = torch.empty((n_groups * num_rows,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    max_fused = 65536 // x.element_size()
    block_n = min(max_fused, triton.next_power_of_2(group_size))
    if group_size > block_n:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    warps = min(max(block_n // 256, 1), 8)
    # Cap grid dim 0 at MAX_CORES; the kernel loops over extra rows itself.
    grid = (min(num_rows, MAX_CORES), n_groups)
    with torch.npu.device(x.device.index):
        layer_norm_fwd_kernel[grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z.stride(0) if z is not None else 0,
            num_rows,
            group_size,
            eps,
            BLOCK_N=block_n,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            N_CORES=MAX_CORES,
            num_warps=warps,
        )
    return out, mean, rstd

164 

165 

166class LayerNormFn(torch.autograd.Function): 

167 @staticmethod 

168 def forward( 

169 ctx, 

170 x, 

171 weight, 

172 bias, 

173 z=None, 

174 eps=1e-6, 

175 group_size=None, 

176 norm_before_gate=True, 

177 is_rms_norm=False, 

178 ): 

179 """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" 

180 

181 x_shape_og = x.shape 

182 # reshape input data into 2D tensor 

183 x = x.reshape(-1, x.shape[-1]) 

184 if x.stride(-1) != 1: 

185 x = x.contiguous() 

186 if z is not None: 

187 assert z.shape == x_shape_og 

188 z = z.reshape(-1, z.shape[-1]) 

189 if z.stride(-1) != 1: 

190 z = z.contiguous() 

191 weight = weight.contiguous() 

192 if bias is not None: 

193 bias = bias.contiguous() 

194 y, mean, rstd = _layer_norm_fwd( 

195 x, 

196 weight, 

197 bias, 

198 eps, 

199 z=z, 

200 group_size=group_size, 

201 norm_before_gate=norm_before_gate, 

202 is_rms_norm=is_rms_norm, 

203 ) 

204 return y.reshape(x_shape_og)