Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/weightnorm.py: 0%

171 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module logger as a child of the package-wide "flag_gems" logger; any leading
# dots are stripped from __name__ before it is used as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

def weight_norm_kernel_last_block_row(args):
    """BLOCK_ROW_SIZE heuristic for ``weight_norm_kernel_last``.

    Fixed at 1 for this backend. The previous ``min(args["M"], 8192)``
    tuning path sat below an unconditional ``return 1`` and was dead
    code, so it has been removed; behavior is unchanged.
    """
    return 1

21 

22 

def weight_norm_kernel_last_block_col(args):
    """BLOCK_COL_SIZE heuristic for ``weight_norm_kernel_last``.

    Divides the N columns by 12 and rounds the result up to the next
    power of two (block sizes must be powers of two for ``tl.arange``).
    """
    cols_per_program = triton.cdiv(args["N"], 12)
    return triton.next_power_of_2(cols_per_program)

26 

27 

# Forward weight-norm kernel for the case where the normalized dim is the
# LAST axis: `v` is viewed as a contiguous (M, N) matrix and the reduction
# runs over the M leading elements of each column.
#
#   norm[j]      = sqrt(sum_i v[i, j]^2 + eps)
#   output[i, j] = v[i, j] / norm[j] * g[j]
#
# Grid: one program per tile of BLOCK_COL_SIZE columns (axis 0).
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": weight_norm_kernel_last_block_row,
        "BLOCK_COL_SIZE": weight_norm_kernel_last_block_col,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_last(
    output,
    norm,
    v,
    g,
    M: tl.constexpr,
    N: tl.constexpr,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Columns owned by this program, shaped (BLOCK_COL_SIZE, 1) so the row
    # offsets below broadcast along axis 1.
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tle.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = bx + tx
    col_mask = col_offset < N

    # Pass 1: accumulate per-column sums of squares across all M rows.
    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
    v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    # eps (the host wrapper passes torch.finfo(float32).tiny) keeps the sqrt
    # and the later division away from zero.
    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + col_offset, normalized[:, None], mask=col_mask)
    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)

    # Pass 2: re-read v and write output = v / norm * g.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)

74 

75 

def weight_norm_kernel_first_block_row(args):
    """BLOCK_ROW_SIZE heuristic for ``weight_norm_kernel_first``.

    Divides the M rows by 12 and rounds the result up to the next power
    of two (block sizes must be powers of two for ``tl.arange``).
    """
    rows_per_program = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_program)

78 

79 

def weight_norm_kernel_first_block_col(args):
    """BLOCK_COL_SIZE heuristic for ``weight_norm_kernel_first``: fixed at 1.

    ``args`` is accepted only because ``triton.heuristics`` passes the
    kernel arguments to every heuristic; it is not used.
    """
    del args  # unused
    return 1

82 

83 

# Forward weight-norm kernel for the case where the normalized dim is the
# FIRST axis: `v` is viewed as a contiguous (M, N) matrix and the reduction
# runs over the N trailing elements of each row.
#
#   norm[i]      = sqrt(sum_j v[i, j]^2 + eps)
#   output[i, j] = v[i, j] / norm[i] * g[i]
#
# Grid: one program per tile of BLOCK_ROW_SIZE rows (axis 0).
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": weight_norm_kernel_first_block_row,
        "BLOCK_COL_SIZE": weight_norm_kernel_first_block_col,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_first(
    output,
    norm,
    v,
    g,
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Rows owned by this program, shaped (BLOCK_ROW_SIZE, 1) so the column
    # offsets below broadcast along axis 1.
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    # Pass 1: accumulate per-row sums of squares across all N columns.
    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    # eps (the host wrapper passes torch.finfo(float32).tiny) keeps the sqrt
    # and the later division away from zero.
    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + row_offset, normalized[:, None], mask=row_mask)
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)

    # Pass 2: re-read v and write output = v / norm * g.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)

130 

131 

def heur_block_m_weight_norm_bwd_kernel_last(args):
    """BLOCK_ROW_SIZE heuristic for ``weight_norm_bwd_kernel_last``: fixed at 1.

    ``args`` is required by the ``triton.heuristics`` calling convention
    but is not consulted.
    """
    del args  # unused
    return 1

134 

135 

def heur_block_n_weight_norm_bwd_kernel_last(args):
    """BLOCK_COL_SIZE heuristic for ``weight_norm_bwd_kernel_last``.

    Divides the N columns by 12 and rounds the result up to the next
    power of two (block sizes must be powers of two for ``tl.arange``).
    """
    cols_per_program = triton.cdiv(args["N"], 12)
    return triton.next_power_of_2(cols_per_program)

138 

139 

# Backward weight-norm kernel for the last-dim case. `w` is the upstream
# gradient w.r.t. the forward output; `v`, `g`, `norm` are the saved forward
# tensors, all viewed as contiguous (M, N) / (N,) buffers. With
# vw_sum[j] = sum_i v[i, j] * w[i, j], each program emits:
#
#   v_grad[i, j] = g[j] * ( w[i, j] / (norm[j] + eps)
#                           - v[i, j] * vw_sum[j] / (norm[j]^3 + eps) )
#   g_grad[j]    = vw_sum[j] / (norm[j] + eps)
#
# Grid: one program per tile of BLOCK_COL_SIZE columns (axis 0).
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": heur_block_m_weight_norm_bwd_kernel_last,
        "BLOCK_COL_SIZE": heur_block_n_weight_norm_bwd_kernel_last,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_last(
    v_grad,
    g_grad,
    w,
    v,
    g,
    norm,
    M: tl.constexpr,
    N: tl.constexpr,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Columns owned by this program, shaped (BLOCK_COL_SIZE, 1) so row
    # offsets broadcast along axis 1.
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tle.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = tx + bx
    col_mask = col_offset < N

    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)
    norm_value = tl.load(norm + col_offset, mask=col_mask).to(tl.float32)

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]

    # Pass 1: per-column dot product of v with the upstream gradient w.
    vw_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        vw_block += v_value * w_value
    vw_sum = tl.sum(vw_block, 1)[:, None]

    # Pass 2: re-read v and w, write the gradient w.r.t. v. eps guards both
    # denominators against zero norms.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    # Gradient w.r.t. the gain g needs only the reduction result.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + col_offset, g_grad_value, mask=col_mask)

196 

197 

def heur_block_m_weight_norm_bwd_kernel_first(args):
    """BLOCK_ROW_SIZE heuristic for ``weight_norm_bwd_kernel_first``.

    Divides the M rows by 12 and rounds the result up to the next power
    of two (block sizes must be powers of two for ``tl.arange``).
    """
    rows_per_program = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_program)

200 

201 

def heur_block_n_weight_norm_bwd_kernel_first(args):
    """BLOCK_COL_SIZE heuristic for ``weight_norm_bwd_kernel_first``: fixed at 1.

    ``args`` is required by the ``triton.heuristics`` calling convention
    but is not consulted.
    """
    del args  # unused
    return 1

204 

205 

# Backward weight-norm kernel for the first-dim case. `w` is the upstream
# gradient w.r.t. the forward output; `v`, `g`, `norm` are the saved forward
# tensors, all viewed as contiguous (M, N) / (M,) buffers. With
# vw_sum[i] = sum_j v[i, j] * w[i, j], each program emits:
#
#   v_grad[i, j] = g[i] * ( w[i, j] / (norm[i] + eps)
#                           - v[i, j] * vw_sum[i] / (norm[i]^3 + eps) )
#   g_grad[i]    = vw_sum[i] / (norm[i] + eps)
#
# Grid: one program per tile of BLOCK_ROW_SIZE rows (axis 0).
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": heur_block_m_weight_norm_bwd_kernel_first,
        "BLOCK_COL_SIZE": heur_block_n_weight_norm_bwd_kernel_first,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_first(
    v_grad,
    g_grad,
    w,
    v,
    g,
    norm,
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Rows owned by this program, shaped (BLOCK_ROW_SIZE, 1) so column
    # offsets broadcast along axis 1.
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # Pass 1: per-row dot product of v with the upstream gradient w.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * w_value
    vw_sum = tl.sum(v_block, 1)[:, None]

    # Pass 2: re-read v and w, write the gradient w.r.t. v. eps guards both
    # denominators against zero norms.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    # Gradient w.r.t. the gain g needs only the reduction result.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)

262 

263 

def weight_norm_interface(v, g, dim=0):
    """Weight-normalization forward: returns ``(output, norm)``.

    ``norm`` holds the L2 norms of ``v`` taken over every dim except
    ``dim`` (plus a tiny eps under the sqrt), and
    ``output = v / norm * g``.

    Args:
        v: weight tensor (made contiguous internally).
        g: gain tensor shaped like the ``dim`` slice of ``v``.
        dim: the dim kept un-reduced; the kernels only support the first
            (``0``) or last (``v.ndim - 1``) dim.

    Raises:
        ValueError: for any other ``dim`` — previously this fell through
            both branches and silently returned uninitialized buffers.
    """
    logger.debug("GEMS WEIGHT NORM INTERFACE FORWARD")
    v = v.contiguous()
    g = g.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g)
    if dim == 0:
        # First-dim case: view v as (M, N), reduce over the trailing N.
        M = v.shape[0]
        N = math.prod(v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_first[grid](
                output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
            )
    elif dim == v.ndim - 1:
        # Last-dim case: view v as (M, N), reduce over the leading M.
        M = math.prod(v.shape[:-1])
        N = v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_last[grid](
                output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
            )
    else:
        raise ValueError(
            f"weight_norm_interface only supports dim 0 or {v.ndim - 1}, got {dim}"
        )
    return output, norm

287 

288 

def weight_norm_interface_backward(w_grad, saved_v, saved_g, saved_norms, dim):
    """Backward of :func:`weight_norm_interface`: returns ``(v_grad, g_grad)``.

    Args:
        w_grad: upstream gradient w.r.t. the forward output.
        saved_v: the ``v`` tensor saved by the forward pass.
        saved_g: the ``g`` tensor saved by the forward pass.
        saved_norms: the ``norm`` tensor produced by the forward pass.
        dim: the normalized dim; must be ``0`` or ``saved_v.ndim - 1``.

    Raises:
        ValueError: for any other ``dim`` — previously this fell through
            both branches and silently returned uninitialized buffers.
    """
    logger.debug("GEMS WEIGHT NORM INTERFACE BACKWARD")
    w_grad = w_grad.contiguous()
    saved_v = saved_v.contiguous()
    saved_g = saved_g.contiguous()
    saved_norms = saved_norms.contiguous()
    v_grad = torch.empty_like(saved_v)
    g_grad = torch.empty_like(saved_g)

    if dim == 0:
        # First-dim case: view saved_v as (M, N), reduce over the trailing N.
        M = saved_v.shape[0]
        N = math.prod(saved_v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_first[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=torch.finfo(torch.float32).tiny,
            )
    elif dim == saved_v.ndim - 1:
        # Last-dim case: view saved_v as (M, N), reduce over the leading M.
        M = math.prod(saved_v.shape[:dim])
        N = saved_v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_last[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=torch.finfo(torch.float32).tiny,
            )
    else:
        raise ValueError(
            f"weight_norm_interface_backward only supports dim 0 or "
            f"{saved_v.ndim - 1}, got {dim}"
        )
    return v_grad, g_grad