Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/batch_norm.py: 0%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, tl_extra_shim 

11 

# Module-level logger nested under the package logger; lstrip(".") drops any
# leading dots from __name__ so the child-logger name stays well-formed.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Reciprocal square root provided by the backend's triton-language shim.
rsqrt = tl_extra_shim.rsqrt

14 

15 

def make_3d_for_bn(input: Tensor) -> Tensor:
    """Return a 3D (batch, feature, spatial) view of ``input``.

    A 2D input gains a trailing spatial axis of length 1; an input with four
    or more dimensions has all trailing axes folded into a single spatial
    axis.  A 3D input passes through unchanged.

    Args:
        input: Tensor to render 3D.

    Returns:
        Input's 3D view.
    """
    ndim = input.ndim
    if ndim == 2:
        return input.unsqueeze(-1)
    if ndim >= 4:
        return input.flatten(2, -1)
    return input

33 

34 

35# NOTE: This part of the kernel code is copied and modified 

36# from the https://github.com/BobMcDear/attorch codebase. 

37 

38 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("batch_norm"),
#     key=["batch_dim", "spatial_dim"],
#     restore_value=["running_mean_pointer", "running_var_pointer"],
# )
@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_forward_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    mean_pointer,
    inv_std_pointer,
    output_pointer,
    running_mean_pointer,
    running_var_pointer,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-norm forward pass; one program normalizes one feature (channel).

    The data is addressed as (batch, feature, spatial) through the given
    strides.  In training mode the per-feature mean and inverse std are
    computed here, stored to ``mean_pointer`` / ``inv_std_pointer``, and
    folded into the running statistics; in eval mode the statistics are read
    from the running buffers instead.

    Args:
        input_pointer: data to normalize.
        weight_pointer, bias_pointer: optional affine parameters; a None
            pointer selects the constants 1.0 / 0.0 below.
        mean_pointer, inv_std_pointer: per-feature stat outputs (written in
            training mode only).
        running_mean_pointer, running_var_pointer: running statistics,
            updated in training mode and read in eval mode — must be valid
            pointers in both modes.
        batch_dim, spatial_dim: extents of the batch and spatial axes.
        momentum: EMA factor for the running-statistics update.
        eps: numerical-stability term added to the variance.
        is_train: compile-time switch between train and eval behavior.
        BLOCK_M, BLOCK_N: tile sizes over the batch and spatial axes.
    """
    feat_pid = tl.program_id(axis=0)

    # Training mode: compute batch statistics (running stats are assumed to
    # be tracked).
    if is_train:
        # Per-lane running mean, sum of squared deviations, and valid-sample
        # count; lanes are merged into feature-level stats after the loops.
        mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        var = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        cnt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

        m_num_steps = tl.cdiv(batch_dim, BLOCK_M)
        n_num_steps = tl.cdiv(spatial_dim, BLOCK_N)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                # Masked-off lanes are undefined after this load; the
                # tl.where updates below keep their accumulators unchanged,
                # so the undefined values never propagate.
                curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)

                # Welford-style running update; `step` is the 1-based global
                # tile index.  NOTE(review): the division uses the global
                # `step` rather than a per-lane valid count, which can
                # mis-weight lanes that were masked in an earlier tile but
                # valid in a later one (possible when both dims have partial
                # tiles) — confirm against the upstream attorch kernel.
                step = m_step * n_num_steps + n_step + 1
                new_mean = tl.where(mask, mean + (curr_input - mean) / step, mean)
                new_var = tl.where(
                    mask, var + (curr_input - new_mean) * (curr_input - mean), var
                )
                cnt += mask.to(tl.int32)
                mean = new_mean
                var = new_var

        # Merge per-lane partial statistics into one feature-level pair:
        # count-weighted mean, then total squared deviation re-centered at
        # the merged mean (invalid lanes contribute zero since cnt == 0).
        final_mean = tl.sum(mean * cnt) / (batch_dim * spatial_dim)
        var = tl.sum(var + cnt * (mean - final_mean) * (mean - final_mean)) / (
            batch_dim * spatial_dim
        )
        inv_std = rsqrt(var + eps)
        mean = final_mean

        # Save the batch statistics for the backward pass.
        tl.store(feat_pid + mean_pointer, mean)
        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid

        running_mean = tl.load(running_mean_pointer)
        running_var = tl.load(running_var_pointer)

        # EMA update of the running stats; the running variance uses the
        # unbiased estimate, hence the n / (n - 1) factor.
        n = batch_dim * spatial_dim
        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
        tl.store(
            running_var_pointer,
            (1 - momentum) * running_var + momentum * var * n / (n - 1),
        )

    else:
        # Eval mode: statistics come straight from the running buffers.
        mean = tl.load(feat_pid + running_mean_pointer)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

    # Optional affine parameters; the truthiness test on a None pointer is
    # resolved when the kernel is specialized.
    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    if bias_pointer:
        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
    else:
        bias = 0.0

    # Second pass: normalize each tile and apply the affine transform.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_output_pointer = (
                output_pointer
                + output_feat_stride * feat_pid
                + output_batch_stride * batch_offset[:, None]
                + output_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            output = weight * (curr_input - mean) * inv_std + bias

            tl.store(
                curr_output_pointer,
                output,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )

175 

176 

def batch_norm_heur_block_m(args):
    """Choose BLOCK_M: the next power of two covering the batch, capped at 64."""
    batch = args["batch_dim"]
    return min(triton.next_power_of_2(batch), 64)

179 

180 

def batch_norm_heur_block_n(args):
    """Choose BLOCK_N so a tile never exceeds 2**14 (16384) elements."""
    block_m = batch_norm_heur_block_m(args)
    # Budget the spatial tile against the already-chosen batch tile.
    budget = max(1, 2**14 // block_m)
    return min(triton.next_power_of_2(args["spatial_dim"]), budget)

186 

187 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("batch_norm"),
#     key=["batch_dim", "spatial_dim"],
# )
@triton.heuristics(
    values={
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
)
# @triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_backward_kernel(
    output_grad_pointer,
    input_pointer,
    mean_pointer,
    inv_std_pointer,
    weight_pointer,
    input_grad_pointer,
    weight_grad_pointer,
    bias_grad_pointer,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-norm backward pass; one program handles one feature (channel).

    Pass 1 reduces two per-feature sums over all (batch, spatial) positions:
    ``term1`` = sum(pre_lin * dout) (the weight gradient) and ``term2`` =
    sum(dout) (the bias gradient), where pre_lin is the normalized input.
    Pass 2 re-reads the data and produces the input gradient from those sums.

    Args:
        output_grad_pointer: upstream gradient, (batch, feature, spatial).
        input_pointer: forward-pass input, same layout.
        mean_pointer, inv_std_pointer: per-feature stats saved by the forward.
        weight_pointer: optional affine weight (None selects 1.0).
        input_grad_pointer, weight_grad_pointer, bias_grad_pointer: outputs;
            each is written only when its corresponding *grad_mask flag is set.
        input_grad_mask, weight_grad_mask, bias_grad_mask: compile-time flags
            selecting which gradients to materialize.
        BLOCK_M, BLOCK_N: tile sizes over the batch and spatial axes.
    """
    feat_pid = tl.program_id(axis=0)

    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)

    # Per-lane partial sums, reduced to scalars after the first pass.
    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    # Pass 1: accumulate sum(pre_lin * dout) and sum(dout) for this feature.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
        batch_mask = batch_offset < batch_dim

        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )

            mask = batch_mask[:, None] & spatial_mask[None, :]
            # other=0 so masked lanes contribute nothing to the sums.
            curr_input = tl.load(curr_input_pointer, mask=mask, other=0).to(tl.float32)

            # Normalized input (the value before the affine transform).
            curr_pre_lin = ((curr_input - mean) * inv_std).to(tl.float32)
            curr_output_grad = tl.load(
                curr_output_grad_pointer, mask=mask, other=0.0
            ).to(tl.float32)

            term1 += curr_pre_lin * curr_output_grad
            term2 += curr_output_grad

    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

    # term1 / term2 are exactly the weight / bias gradients for this feature.
    if weight_grad_mask:
        tl.store(feat_pid + weight_grad_pointer, term1)
    if bias_grad_mask:
        tl.store(feat_pid + bias_grad_pointer, term2)

    # No input gradient requested: the input_grad pointer/strides are never
    # touched beyond this point.
    if not input_grad_mask:
        return

    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    # Ensure fp32 on both branches before the arithmetic below.
    weight = weight.to(tl.float32)

    count = batch_dim * spatial_dim

    # Pass 2: d_input = inv_std * w * (dout - (term1 * pre_lin + term2) / N).
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_input_grad_pointer = (
                input_grad_pointer
                + input_grad_feat_stride * feat_pid
                + input_grad_batch_stride * batch_offset[:, None]
                + input_grad_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            ).to(tl.float32)
            curr_input_grad = (
                inv_std
                * weight
                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
            )
            tl.store(
                curr_input_grad_pointer,
                curr_input_grad,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )

331 

332 

def batch_norm(
    input: Tensor,
    weight=None,
    bias=None,
    running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
    running_var=None,
    training=False,  # (self.running_mean is None) and (self.running_var is None)
    momentum=0.1,
    eps=1e-05,
):
    """Batch normalization forward.

    Args:
        input: tensor of shape (N, C) or (N, C, *spatial).
        weight, bias: optional per-channel affine parameters.
        running_mean, running_var: optional running statistics; updated in
            place by the kernel when ``training`` is True.
        training: True to compute batch statistics, False to use the
            running ones.
        momentum: EMA factor for the running-statistics update.
        eps: numerical-stability term added to the variance.

    Returns:
        (output, save_mean, save_invstd) — output has ``input``'s shape;
        the saved statistics are per-channel (training mode only fills
        them with meaningful values).
    """
    logger.debug("GEMS BATCHNORM FORWARD")

    # Flatten to (batch, feat, spatial), then fold batch and spatial
    # together so the kernel sees a (batch*spatial, feat, 1) layout.
    input_3d_i = make_3d_for_bn(input)
    m, n, k = input_3d_i.shape
    input_3d_f = input_3d_i.permute(0, 2, 1).reshape(-1, n)
    input_3d = make_3d_for_bn(input_3d_f)

    batch_dim, feat_dim, spatial_dim = input_3d.shape
    output = torch.empty_like(input_3d)

    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)

    # The kernel dereferences the running-stat pointers unconditionally
    # (stores in training mode, loads in eval mode), so when the caller
    # tracks no running stats we must supply per-feature buffers.  The
    # previous code passed `input` itself, which let the training-mode
    # store clobber the first `feat_dim` elements of the input tensor.
    # Identity statistics (mean 0, var 1) keep eval mode well-defined;
    # in training mode the scratch buffers are simply discarded.
    if running_mean is None:
        running_mean = torch.zeros(feat_dim, device=input.device, dtype=input.dtype)
    if running_var is None:
        running_var = torch.ones(feat_dim, device=input.device, dtype=input.dtype)

    # Launches 1D grid where each program operates over one feature.
    with torch_device_fn.device(input.device):
        batch_norm_forward_kernel[(feat_dim,)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            running_mean,
            running_var,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            momentum,
            eps,
            is_train=training,
            buffer_size_limit=2048,
        )

    # Undo the (batch*spatial, feat, 1) folding and restore input's shape.
    output_reshaped = output.reshape(m, k, n).permute(0, 2, 1)
    return output_reshaped.view_as(input), mean, inv_std

383 

384 

def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-05,
    output_mask=None,
):
    """Batch normalization backward.

    Args:
        grad_out: upstream gradient with ``input``'s shape.
        input: forward-pass input, (N, C) or (N, C, *spatial).
        weight: optional per-channel affine weight.
        running_mean, running_var: unused here; accepted for signature
            compatibility with the aten op.
        save_mean, save_invstd: per-channel statistics saved by the forward.
        train, eps: accepted for signature compatibility.
        output_mask: three booleans selecting (input_grad, weight_grad,
            bias_grad); None computes all three.

    Returns:
        (input_grad, weight_grad, bias_grad); entries deselected by
        ``output_mask`` are None.
    """
    logger.debug("GEMS BATCHNORM BACKWARD")

    # Fix: the old code unconditionally subscripted output_mask (TypeError
    # on the default None).
    if output_mask is None:
        output_mask = (True, True, True)

    # Same (batch*spatial, feat, 1) folding as the forward pass.
    input_3d_i = make_3d_for_bn(input)
    m, n, k = input_3d_i.shape
    input_3d_f = input_3d_i.permute(0, 2, 1).reshape(-1, n)
    input_3d = make_3d_for_bn(input_3d_f)

    output_grad_3d_i = make_3d_for_bn(grad_out)
    output_grad_3d_f = output_grad_3d_i.permute(0, 2, 1).reshape(-1, n)
    output_grad_3d = make_3d_for_bn(output_grad_3d_f)

    batch_dim, feat_dim, spatial_dim = input_3d.shape

    input_grad = torch.empty_like(input_3d) if output_mask[0] else None
    weight_grad = (
        torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
        if output_mask[1]
        else None
    )
    bias_grad = (
        torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
        if output_mask[2]
        else None
    )

    # Fix: when the input gradient is not requested there is no tensor to
    # take strides from (the old code crashed on input_grad.stride()).  The
    # kernel never touches the pointer when input_grad_mask is False, so
    # zero strides are safe placeholders.
    input_grad_strides = (
        input_grad.stride() if input_grad is not None else (0, 0, 0)
    )

    # Launches 1D grid where each program operates over one feature.
    with torch_device_fn.device(input.device):
        batch_norm_backward_kernel[(feat_dim, 1, 1)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            weight,
            input_grad,
            weight_grad,
            bias_grad,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            *input_grad_strides,
            *output_mask,
            buffer_size_limit=2048,
        )

    # Pads output with None because a gradient is necessary for
    # all input arguments.
    input_grad_out = (
        input_grad.reshape(m, k, n).permute(0, 2, 1).view_as(input)
        if input_grad is not None
        else None
    )
    return input_grad_out, weight_grad, bias_grad