Coverage for src/flag_gems/ops/weightnorm.py: 29%

160 statements  

coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, tl_extra_shim
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger(__name__)

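# Weight normalization reparameterizes a tensor as w = g * v / ||v||, with the
# norm taken over every dimension of v except `dim`. The kernels below compute
# that norm and the normalized output in two streaming passes over the
# reduction axis. The following is a minimal eager-mode sketch of the same
# math -- a hypothetical helper for illustration and testing only, not part of
# this module; it assumes g is contiguous with one scalar per norm group:

def _weight_norm_fwd_ref(v, g, dim):
    # Reduce over all axes except `dim`; keepdim so the norm broadcasts back.
    reduce_dims = [d for d in range(v.ndim) if d != dim]
    norm = torch.linalg.vector_norm(v.float(), ord=2, dim=reduce_dims, keepdim=True)
    out = (v.float() / norm * g.float().view(norm.shape)).to(v.dtype)
    return out, norm.to(g.dtype).view(g.shape)
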

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_last(
    output,
    norm,
    v,
    g,
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Norm dim is the last axis: each program owns BLOCK_COL_SIZE columns of
    # the flattened (M, N) view of v and reduces over all M rows.
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tle.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = bx + tx
    col_mask = col_offset < N

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
    # First pass: accumulate sum(v * v) in float32, one row tile at a time.
    v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + col_offset, normalized[:, None], mask=col_mask)
    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)

    # Second pass: re-read v and write out g * v / norm.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)

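# The "first"-dim variant below is the transpose of the kernel above: the norm
# axis moves to dim 0, so each program owns a tile of rows and streams across
# the N columns instead.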

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_first(
    output,
    norm,
    v,
    g,
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Norm dim is the first axis: each program owns BLOCK_ROW_SIZE rows and
    # reduces over all N columns.
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # First pass: accumulate sum(v * v) in float32, one column tile at a time.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + row_offset, normalized[:, None], mask=row_mask)
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)

    # Second pass: re-read v and write out g * v / norm.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)

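# Backward pass. With w = g * v / n and n = ||v|| per norm group, the chain
# rule for an upstream gradient dw gives
#     dg = sum(dw * v) / n
#     dv = g * (dw / n - v * sum(dw * v) / n**3),
# which is what the two kernels below evaluate tile by tile, reusing the norm
# saved by the forward pass. An eager-mode sketch of the same formulas
# (hypothetical helper for illustration only, ignoring the eps guard the
# kernels add to the denominators):

def _weight_norm_bwd_ref(dw, v, g, n, dim):
    reduce_dims = [d for d in range(v.ndim) if d != dim]
    # Per-group inner product sum(dw * v), kept broadcastable via keepdim.
    vw_sum = (dw * v).sum(dim=reduce_dims, keepdim=True)
    n = n.view(vw_sum.shape)
    dv = g.view(vw_sum.shape) * (dw / n - v * vw_sum / n**3)
    return dv, vw_sum / n
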

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_last(
    v_grad,
    g_grad,
    w,
    v,
    g,
    norm,
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tle.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = tx + bx
    col_mask = col_offset < N

    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)
    norm_value = tl.load(norm + col_offset, mask=col_mask).to(tl.float32)
    norm_1 = 1 / (norm_value + eps)
    norm_3 = tl_extra_shim.pow(norm_1, 3)

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]

    # First pass: per-column inner product sum(v * dw) over the rows.
    vw_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        vw_block += v_value * w_value
    vw_sum = tl.sum(vw_block, 1)[:, None]

    # Second pass: dv = g * (dw / n - v * sum(v * dw) / n**3).
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (w_value * norm_1 - v_value * norm_3 * vw_sum)
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    # dg = sum(v * dw) / n.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + col_offset, g_grad_value, mask=col_mask)


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_first(
    v_grad,
    g_grad,
    w,
    v,
    g,
    norm,
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)
    norm_1 = 1 / (norm_value + eps)
    norm_3 = tl_extra_shim.pow(norm_1, 3)

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # First pass: per-row inner product sum(v * dw) over the columns.
    vw_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        vw_block += v_value * w_value
    vw_sum = tl.sum(vw_block, 1)[:, None]

    # Second pass: dv = g * (dw / n - v * sum(v * dw) / n**3).
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (w_value * norm_1 - v_value * norm_3 * vw_sum)
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    # dg = sum(v * dw) / n.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)


def weight_norm_interface(v, g, dim=0):
    logger.debug("GEMS WEIGHT NORM INTERFACE FORWARD")
    v = v.contiguous()
    g = g.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g)
    if dim == 0:
        # Norm over everything but the first axis: flatten to (M, N), one norm per row.
        M = v.shape[0]
        N = math.prod(v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_first[grid](
                output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
            )
    elif dim == v.ndim - 1:
        # Norm over everything but the last axis: flatten to (M, N), one norm per column.
        M = math.prod(v.shape[:-1])
        N = v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_last[grid](
                output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
            )
    # Only the first and last dim are handled; any other `dim` falls through
    # and returns uninitialized buffers.
    return output, norm

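# Usage sketch (illustrative; assumes a flag_gems-supported device and the
# hypothetical `_weight_norm_fwd_ref` helper defined above):
#
#     v = torch.randn(64, 128, device="cuda")
#     g = torch.randn(64, 1, device="cuda")
#     out, norm = weight_norm_interface(v, g, dim=0)
#     ref_out, ref_norm = _weight_norm_fwd_ref(v, g, dim=0)
#     torch.testing.assert_close(out, ref_out, rtol=1e-4, atol=1e-4)
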

def weight_norm_interface_backward(w_grad, saved_v, saved_g, saved_norms, dim):
    logger.debug("GEMS WEIGHT NORM INTERFACE BACKWARD")
    w_grad = w_grad.contiguous()
    saved_v = saved_v.contiguous()
    saved_g = saved_g.contiguous()
    saved_norms = saved_norms.contiguous()
    v_grad = torch.empty_like(saved_v)
    g_grad = torch.empty_like(saved_g)

    # Mirrors the forward dispatch: dim == 0 uses the row-major kernel,
    # dim == ndim - 1 the column-major one.
    if dim == 0:
        M = saved_v.shape[0]
        N = math.prod(saved_v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_first[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=torch.finfo(torch.float32).tiny,
            )
    elif dim == saved_v.ndim - 1:
        M = math.prod(saved_v.shape[:dim])
        N = saved_v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_last[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=torch.finfo(torch.float32).tiny,
            )
    return v_grad, g_grad
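
# Backward usage sketch (illustrative; assumes the same setup and compares the
# kernel gradients against autograd through the hypothetical forward reference):
#
#     v = torch.randn(64, 128, device="cuda", requires_grad=True)
#     g = torch.randn(64, 1, device="cuda", requires_grad=True)
#     out, norm = weight_norm_interface(v.detach(), g.detach(), dim=0)
#     w_grad = torch.randn_like(out)
#     v_grad, g_grad = weight_norm_interface_backward(
#         w_grad, v.detach(), g.detach(), norm, dim=0
#     )
#     ref_out, _ = _weight_norm_fwd_ref(v, g, dim=0)
#     ref_out.backward(w_grad)
#     torch.testing.assert_close(v_grad, v.grad, rtol=1e-4, atol=1e-4)
#     torch.testing.assert_close(g_grad, g.grad, rtol=1e-4, atol=1e-4)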