Coverage for src/flag_gems/ops/batch_norm.py: 35%

154 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, tl_extra_shim 

11 

12logger = logging.getLogger(__name__) 

13rsqrt = tl_extra_shim.rsqrt 

14 

15 

def make_3d_for_bn(input: Tensor) -> Tensor:
    """Return a 3D (batch, feature, spatial) view of ``input`` for batch norm.

    A 2D input gains a trailing singleton spatial axis; an input with four
    or more dimensions has all trailing spatial axes flattened into one.
    A 3D input is returned unchanged.

    Args:
        input: Tensor to reshape into batch-norm layout.

    Returns:
        A 3D view of ``input``.
    """
    ndim = input.ndim
    if ndim == 2:
        return input.unsqueeze(-1)
    if ndim >= 4:
        return input.flatten(2, -1)
    return input

33 

34 

35# NOTE: This part of the kernel code is copied and modified 

36# from the https://github.com/BobMcDear/attorch codebase. 

37 

38 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("batch_norm"),
    key=["batch_dim", "spatial_dim"],
    # Autotuning re-executes the kernel; the running statistics are updated
    # in place, so they must be restored between tuning runs to avoid
    # applying the momentum update more than once.
    restore_value=["running_mean_pointer", "running_var_pointer"],
)
@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_forward_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    mean_pointer,
    inv_std_pointer,
    output_pointer,
    running_mean_pointer,
    running_var_pointer,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-norm forward for one feature (channel) of a 3D input.

    Grid: 1D over features; each program normalizes the (batch, spatial)
    slice at its feature index. In training mode the program computes the
    batch mean/variance, stores per-feature mean and inv_std, and updates
    the running statistics in place. In eval mode it loads the running
    statistics instead.
    """
    feat_pid = tl.program_id(axis=0)

    # Training mode: compute batch statistics (callers are expected to
    # track running statistics in this mode — see the in-place update below).
    if is_train:
        # Per-lane Welford accumulators; combined into scalars after the loop.
        mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        var = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        cnt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

        m_num_steps = tl.cdiv(batch_dim, BLOCK_M)
        n_num_steps = tl.cdiv(spatial_dim, BLOCK_N)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)

                # Welford running update: `step` is the number of tiles each
                # lane has seen so far; masked-out lanes keep their old values.
                step = m_step * n_num_steps + n_step + 1
                new_mean = tl.where(mask, mean + (curr_input - mean) / step, mean)
                new_var = tl.where(
                    mask, var + (curr_input - new_mean) * (curr_input - mean), var
                )
                cnt += mask.to(tl.int32)
                mean = new_mean
                var = new_var

        # Combine the per-lane partial statistics into scalar mean/variance
        # (parallel-Welford merge: correct partial variances by the squared
        # distance of each lane mean from the global mean).
        final_mean = tl.sum(mean * cnt) / (batch_dim * spatial_dim)
        var = tl.sum(var + cnt * (mean - final_mean) * (mean - final_mean)) / (
            batch_dim * spatial_dim
        )
        inv_std = rsqrt(var + eps)
        mean = final_mean

        # Save batch statistics for the backward pass.
        tl.store(feat_pid + mean_pointer, mean)
        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid

        running_mean = tl.load(running_mean_pointer)
        running_var = tl.load(running_var_pointer)

        # In-place EMA update of running stats; the running variance uses the
        # unbiased estimator, hence the n / (n - 1) correction.
        n = batch_dim * spatial_dim
        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
        tl.store(
            running_var_pointer,
            (1 - momentum) * running_var + momentum * var * n / (n - 1),
        )

    else:
        # Eval mode: normalize with the tracked running statistics.
        mean = tl.load(feat_pid + running_mean_pointer)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

    # Null weight/bias pointers mean "no affine transform" (identity scale,
    # zero shift).
    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    if bias_pointer:
        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
    else:
        bias = 0.0

    # Second pass: normalize and write the output tile by tile.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_output_pointer = (
                output_pointer
                + output_feat_stride * feat_pid
                + output_batch_stride * batch_offset[:, None]
                + output_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            output = weight * (curr_input - mean) * inv_std + bias

            tl.store(
                curr_output_pointer,
                output,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )

175 

176 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("batch_norm"),
    key=["batch_dim", "spatial_dim"],
)
@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_backward_kernel(
    output_grad_pointer,
    input_pointer,
    mean_pointer,
    inv_std_pointer,
    weight_pointer,
    input_grad_pointer,
    weight_grad_pointer,
    bias_grad_pointer,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-norm backward for one feature (channel) of a 3D input.

    Grid: 1D over features. Uses the saved per-feature mean and inv_std
    from the forward pass. The *_grad_mask constexpr flags select which of
    the three gradients (input, weight, bias) are computed and stored;
    pointers for unselected gradients are never dereferenced.
    """
    feat_pid = tl.program_id(axis=0)

    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)

    # First pass reductions over the whole (batch, spatial) slice:
    #   term1 = sum(x_hat * dL/dy)  -> weight gradient
    #   term2 = sum(dL/dy)          -> bias gradient
    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )

            mask = batch_mask[:, None] & spatial_mask[None, :]
            curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)

            # x_hat: the normalized (pre-affine) activation.
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(curr_output_grad_pointer, mask=mask).to(
                tl.float32
            )

            term1 += curr_pre_lin * curr_output_grad
            term2 += curr_output_grad

    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

    if weight_grad_mask:
        tl.store(feat_pid + weight_grad_pointer, term1)
    if bias_grad_mask:
        tl.store(feat_pid + bias_grad_pointer, term2)

    # Input gradient not requested: the input_grad pointer/strides below
    # must not be touched.
    if not input_grad_mask:
        return

    # Null weight pointer means no affine scale was applied.
    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0

    count = batch_dim * spatial_dim

    # Second pass: standard batch-norm input gradient
    #   dL/dx = inv_std * w * (dL/dy - (term1 * x_hat + term2) / N)
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_input_grad_pointer = (
                input_grad_pointer
                + input_grad_feat_stride * feat_pid
                + input_grad_batch_stride * batch_offset[:, None]
                + input_grad_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            ).to(tl.float32)
            curr_input_grad = (
                inv_std
                * weight
                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
            )
            tl.store(
                curr_input_grad_pointer,
                curr_input_grad,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )

314 

def batch_norm(
    input: Tensor,
    weight=None,
    bias=None,
    running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
    running_var=None,
    training=False,  # (self.running_mean is None) and (self.running_var is None)
    momentum=0.1,
    eps=1e-05,
):
    """Batch normalization forward.

    Reshapes ``input`` to (batch, feat, spatial), launches one kernel
    program per feature, and returns the normalized output together with
    the per-feature mean and inverse standard deviation.
    """
    logger.debug("GEMS BATCHNORM FORWARD")

    x3d = make_3d_for_bn(input)
    batch_dim, feat_dim, spatial_dim = x3d.shape

    out = torch.empty_like(x3d)
    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)

    # The kernel always receives valid running-stat pointers; fall back to
    # `input` as a dummy when the caller tracks no running statistics.
    # NOTE(review): in training mode the kernel writes running stats in
    # place, so this dummy aliases `input` — presumably callers pass real
    # running stats whenever training=True; confirm against call sites.
    if running_mean is None:
        running_mean = input
    if running_var is None:
        running_var = input

    # Launches 1D grid where each program operates over one feature.
    with torch_device_fn.device(input.device):
        batch_norm_forward_kernel[(feat_dim,)](
            x3d,
            weight,
            bias,
            mean,
            inv_std,
            out,
            running_mean,
            running_var,
            batch_dim,
            spatial_dim,
            *x3d.stride(),
            *out.stride(),
            momentum,
            eps,
            is_train=training,
        )

    return out.view_as(input), mean, inv_std

359 

360 

def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-05,
    output_mask=None,
):
    """Batch normalization backward.

    Computes the gradients selected by ``output_mask`` (input, weight,
    bias) using the mean/inv_std saved by the forward pass, launching one
    kernel program per feature.

    Fix: the original unconditionally called ``input_grad.stride()`` and
    ``input_grad.view_as(input)``, which raised AttributeError whenever
    ``output_mask[0]`` was False (``input_grad`` is None in that case).

    NOTE(review): ``running_mean``/``running_var``/``train``/``eps`` are
    accepted for signature compatibility but unused — the kernel reads
    only ``save_mean``/``save_invstd``; confirm eval-mode backward needs.
    """
    logger.debug("GEMS BATCHNORM BACKWARD")
    input_3d = make_3d_for_bn(input)
    output_grad_3d = make_3d_for_bn(grad_out)

    batch_dim, feat_dim, spatial_dim = input_3d.shape

    # Allocate only the gradients the caller asked for.
    input_grad = torch.empty_like(input_3d) if output_mask[0] else None
    weight_grad = (
        torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
        if output_mask[1]
        else None
    )
    bias_grad = (
        torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
        if output_mask[2]
        else None
    )

    # When the input gradient is not requested there is no tensor to take
    # strides from. The kernel returns before touching the input-grad
    # pointer when input_grad_mask is False, so zeros are safe placeholders.
    if input_grad is not None:
        input_grad_strides = input_grad.stride()
    else:
        input_grad_strides = (0, 0, 0)

    # Launches 1D grid where each program operates over one feature.
    with torch_device_fn.device(input.device):
        batch_norm_backward_kernel[(feat_dim,)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            weight,
            input_grad,
            weight_grad,
            bias_grad,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            *input_grad_strides,
            *output_mask,
        )

    # Pads output with None because a gradient is necessary for
    # all input arguments.
    return (
        input_grad.view_as(input) if input_grad is not None else None,
        weight_grad,
        bias_grad,
    )