Coverage for src/flag_gems/runtime/backend/_cambricon/ops/weightnorm.py: 0%

228 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import copy 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12 

13from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM 

14 

# Child logger under the "flag_gems" root; leading dots are stripped from
# __name__ so getChild() produces a well-formed dotted logger name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Upper bound on the column-tile width used by config_prune_for_first.
# NOTE(review): presumably derived from Cambricon NRAM tiling limits — confirm.
MAX_N = 31744

17 

18 

# Forward weight norm with the kept dimension LAST: v is viewed as (M, N)
# row-major, the L2 norm reduces over rows (axis 0), and
#   norm[n] = sqrt(sum_m v[m, n]^2 + eps)
#   output[m, n] = g[n] * v[m, n] / norm[n]
# Each program owns BLOCK_COL_SIZE columns and loops over rows in
# BLOCK_ROW_SIZE chunks (two passes over v: reduce, then scale).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_last(
    output,  # (M, N) result buffer
    norm,  # (N,) per-column norm output
    v,  # (M, N) contiguous input weight
    g,  # (N,) per-column gain
    M,
    N,
    eps,  # small constant under the sqrt to avoid division by zero
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tl.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = bx + tx
    col_mask = col_offset < N

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
    # First pass: accumulate per-column sum of squares over all rows.
    v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + col_offset, normalized[:, None], mask=col_mask)
    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)

    # Second pass: re-read v, scale by g / norm, and write the output.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)

59 

60 

def config_prune_for_first(configs, named_args, **kwargs):
    """Early-prune autotune configs for ``weight_norm_kernel_first``.

    When N fits in a single column tile (``N < MAX_N``) the candidate config
    is rewritten in place of the tuned one: ``BLOCK_COL_SIZE`` is pinned to N
    and ``BLOCK_ROW_SIZE`` is sized so the working set fits in NRAM.  When it
    does not fit in one shot, both a non-pipelined (num_stages=1) and a
    software-pipelined (num_stages=3) variant are emitted.  Duplicate
    (row, col, warps, stages) keys are collapsed to one config.

    Args:
        configs: candidate ``triton.Config`` objects from the tuned table.
        named_args: kernel arguments; ``M`` and ``N`` are read.

    Returns:
        list of pruned, de-duplicated configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages = (
            kw["BLOCK_ROW_SIZE"],
            kw["BLOCK_COL_SIZE"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # N fits in one column tile: pin the column block to N and size
            # the row block to what NRAM can hold.
            config = copy.deepcopy(config)
            BLOCK_COL_SIZE = config.kwargs["BLOCK_COL_SIZE"] = N
            m_per_core = math.ceil(M / TOTAL_CORE_NUM)
            # Per-row footprint: ~3 tiles of BLOCK_COL_SIZE floats plus one
            # float for the norm, 4 bytes each.
            nram_usage = (3 * BLOCK_COL_SIZE + 1) * m_per_core * 4
            if nram_usage < MAX_NRAM_SIZE:
                # Whole per-core row range fits: one tile, no pipelining.
                BLOCK_ROW_SIZE = config.kwargs["BLOCK_ROW_SIZE"] = m_per_core
                num_stages = config.num_stages = 1
                key = (BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages)
                configs_map.setdefault(key, config)
            else:
                # Non-pipelined variant: largest row block that fits NRAM.
                max_block_m_without_pipe = (
                    MAX_NRAM_SIZE // 4 // (3 * BLOCK_COL_SIZE + 1)
                )
                BLOCK_ROW_SIZE = config.kwargs[
                    "BLOCK_ROW_SIZE"
                ] = max_block_m_without_pipe
                num_stages = config.num_stages = 1
                key = (BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages)
                configs_map.setdefault(key, config)

                # Pipelined variant: double buffering roughly doubles the
                # per-row footprint, so the row block shrinks accordingly.
                # BUGFIX: this bound was previously computed but never
                # assigned, leaving the 3-stage config with the 1-stage row
                # block and an over-budget NRAM footprint.
                config = copy.deepcopy(config)
                max_block_m_with_pipe = (
                    MAX_NRAM_SIZE // 4 // (6 * BLOCK_COL_SIZE + 4)
                )
                BLOCK_ROW_SIZE = config.kwargs[
                    "BLOCK_ROW_SIZE"
                ] = max_block_m_with_pipe
                num_stages = config.num_stages = 3
                key = (BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages)
                configs_map.setdefault(key, config)
        key = (BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, config)
    return list(configs_map.values())

109 

110 

def tile_mode_for_first(args):
    """Pick the tiling strategy for ``weight_norm_kernel_first``.

    Returns 0 when both the row range (spread over TOTAL_CORE_NUM cores) and
    the full column range fit in a single tile, 1 when only the columns fit,
    and 2 when the columns must be looped as well.
    """
    rows_fit = args["M"] <= args["BLOCK_ROW_SIZE"] * TOTAL_CORE_NUM
    cols_fit = args["N"] <= args["BLOCK_COL_SIZE"]
    if not cols_fit:
        return 2
    return 0 if rows_fit else 1

120 

121 

# Forward weight norm with the kept dimension FIRST: v is viewed as (M, N)
# row-major, the L2 norm reduces over the last axis (N), and
#   norm[m] = sqrt(sum_n v[m, n]^2 + eps)
#   output[m, n] = g[m] * v[m, n] / norm[m]
# TILE_MODE (computed by tile_mode_for_first):
#   0 - one row tile per program, N fits in one column tile (single shot)
#   1 - N fits in one column tile, rows are looped
#   2 - rows and columns both looped (two passes over N per row tile)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_first"),
    key=["M", "N"],
    prune_configs_by={"early_config_prune": config_prune_for_first},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: tile_mode_for_first(args),
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_first(
    output,  # (M, N) result buffer
    norm,  # (M,) per-row norm output
    v,  # (M, N) contiguous input weight
    g,  # (M,) per-row gain
    M,
    N,
    eps,  # small constant under the sqrt to avoid division by zero
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    # Each program owns a contiguous run of split_m rows.
    split_m = tl.cdiv(M, pnum)
    m_start = pid_m * split_m
    if TILE_MODE == 0:
        # Everything fits: load once, reduce, scale, store.
        m_offset = pid_m * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)
        n_offset = tl.arange(0, BLOCK_COL_SIZE)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        v_value = tl.load(v + offset, mask=mask).to(tl.float32)
        normalized = tl.sqrt(tl.sum(v_value * v_value, axis=1) + eps)
        tl.store(norm + m_offset[:, None], normalized[:, None], mask=mask)
        g_value = tl.load(g + m_offset[:, None], mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + offset, out, mask=mask)
    elif TILE_MODE == 1:
        # Full row fits in one column tile; loop over row tiles only.
        for m_idx in range(0, split_m, BLOCK_ROW_SIZE):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_ROW_SIZE)
            n_offset = tl.arange(0, BLOCK_COL_SIZE)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            v_value = tl.load(v + offset, mask=mask).to(tl.float32)
            normalized = tl.sqrt(tl.sum(v_value * v_value, axis=1) + eps)
            tl.store(norm + m_offset[:, None], normalized[:, None], mask=mask)
            g_value = tl.load(g + m_offset[:, None], mask=mask).to(tl.float32)
            v_vec = v_value / normalized[:, None]
            out = v_vec * g_value
            tl.store(output + offset, out, mask=mask)
    else:
        # General case: per row tile, two passes over the columns
        # (reduce sum of squares, then scale and store).
        for m_idx in range(0, split_m, BLOCK_ROW_SIZE):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_ROW_SIZE)
            m_mask = m_offset[:, None] < M
            v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
            for start_n in range(0, N, BLOCK_COL_SIZE):
                n_offset = start_n + tl.arange(0, BLOCK_COL_SIZE)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_mask and n_offset[None, :] < N
                v_value = tl.load(v + offset, mask=mask).to(tl.float32)
                v_block += v_value * v_value

            normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
            tl.store(norm + m_offset[:, None], normalized[:, None], mask=m_mask)
            g_value = tl.load(g + m_offset[:, None], mask=m_mask).to(tl.float32)

            for start_n in range(0, N, BLOCK_COL_SIZE):
                n_offset = start_n + tl.arange(0, BLOCK_COL_SIZE)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_mask and n_offset[None, :] < N
                v_value = tl.load(v + offset, mask=mask).to(tl.float32)
                v_vec = v_value / normalized[:, None]
                out = v_vec * g_value
                tl.store(output + offset, out, mask=mask)

199 

200 

# Backward weight norm for the dim == last case.  w holds the incoming
# gradient of the output.  Per column n (reduction over rows m):
#   vw[n]     = sum_m v[m, n] * w[m, n]
#   v_grad    = g * (w / (norm + eps) - v * vw / (norm^3 + eps))
#   g_grad[n] = vw[n] / (norm[n] + eps)
# eps keeps the divisions finite when a column norm is zero.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_last(
    v_grad,  # (M, N) gradient w.r.t. v (output)
    g_grad,  # (N,) gradient w.r.t. g (output)
    w,  # (M, N) upstream gradient
    v,  # (M, N) saved input weight
    g,  # (N,) saved gain
    norm,  # (N,) saved per-column norm
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tl.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = tx + bx
    col_mask = col_offset < N

    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)
    norm_value = tl.load(norm + col_offset, mask=col_mask).to(tl.float32)

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]

    # First pass: accumulate the per-column dot product sum_m v * w.
    vw_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        vw_block += v_value * w_value
    vw_sum = tl.sum(vw_block, 1)[:, None]

    # Second pass: re-read v and w and emit v_grad elementwise.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + col_offset, g_grad_value, mask=col_mask)

251 

252 

# Backward weight norm for the dim == 0 case.  w holds the incoming gradient
# of the output.  Per row m (reduction over columns n):
#   vw[m]     = sum_n v[m, n] * w[m, n]
#   v_grad    = g * (w / (norm + eps) - v * vw / (norm^3 + eps))
#   g_grad[m] = vw[m] / (norm[m] + eps)
# eps keeps the divisions finite when a row norm is zero.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_first(
    v_grad,  # (M, N) gradient w.r.t. v (output)
    g_grad,  # (M,) gradient w.r.t. g (output)
    w,  # (M, N) upstream gradient
    v,  # (M, N) saved input weight
    g,  # (M,) saved gain
    norm,  # (M,) saved per-row norm
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # First pass: accumulate the per-row dot product sum_n v * w.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * w_value
    vw_sum = tl.sum(v_block, 1)[:, None]

    # Second pass: re-read v and w and emit v_grad elementwise.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)

303 

304 

def weight_norm_interface(v, g, dim=0):
    """Weight normalization forward: output = g * v / ||v||.

    The norm is taken over every dimension of ``v`` except ``dim``.

    Args:
        v: weight tensor (made contiguous internally).
        g: gain tensor, shaped like the kept dimension of ``v``.
        dim: kept dimension; only the first (0) and the last dimension are
            supported. Negative values are normalized against ``v.ndim``.

    Returns:
        (output, norm): the normalized weight and the saved norms.

    Raises:
        ValueError: if ``dim`` is neither 0 nor the last dimension.
            (Previously such calls fell through silently and returned
            uninitialized tensors.)
    """
    logger.debug("GEMS_CAMBRICON WEIGHTNORM FORWARD")
    v = v.contiguous()
    g = g.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g)
    if dim < 0:
        dim += v.ndim
    if dim == 0:
        # Kept dim first: flatten trailing dims into N; fixed one-program-
        # per-core launch, rows are split across cores inside the kernel.
        M = v.shape[0]
        N = math.prod(v.shape[1:])
        with torch_device_fn.device(v.device):
            weight_norm_kernel_first[TOTAL_CORE_NUM, 1, 1](
                output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
            )
    elif dim == v.ndim - 1:
        # Kept dim last: flatten leading dims into M; one program per
        # column tile.
        M = math.prod(v.shape[:-1])
        N = v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_last[grid](
                output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
            )
    else:
        raise ValueError(
            f"weight_norm_interface only supports dim 0 or {v.ndim - 1}, "
            f"got {dim}"
        )
    return output, norm

327 

328 

def weight_norm_interface_backward(w_grad, saved_v, saved_g, saved_norms, dim):
    """Weight normalization backward.

    Args:
        w_grad: upstream gradient of the normalized output.
        saved_v: input weight saved by the forward pass.
        saved_g: gain saved by the forward pass.
        saved_norms: norms saved by the forward pass.
        dim: kept dimension used in the forward pass; only 0 and the last
            dimension are supported. Negative values are normalized.

    Returns:
        (v_grad, g_grad): gradients w.r.t. ``v`` and ``g``.

    Raises:
        ValueError: if ``dim`` is neither 0 nor the last dimension.
            (Previously such calls fell through silently and returned
            uninitialized tensors.)
    """
    logger.debug("GEMS_CAMBRICON WEIGHTNORM BACKWARD")
    w_grad = w_grad.contiguous()
    saved_v = saved_v.contiguous()
    saved_g = saved_g.contiguous()
    saved_norms = saved_norms.contiguous()
    v_grad = torch.empty_like(saved_v)
    g_grad = torch.empty_like(saved_g)

    if dim < 0:
        dim += saved_v.ndim
    if dim == 0:
        # Kept dim first: one program per row tile.
        M = saved_v.shape[0]
        N = math.prod(saved_v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_first[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=torch.finfo(torch.float32).tiny,
            )
    elif dim == saved_v.ndim - 1:
        # Kept dim last: one program per column tile.
        M = math.prod(saved_v.shape[:dim])
        N = saved_v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_last[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=torch.finfo(torch.float32).tiny,
            )
    else:
        raise ValueError(
            f"weight_norm_interface_backward only supports dim 0 or "
            f"{saved_v.ndim - 1}, got {dim}"
        )
    return v_grad, g_grad