Coverage for src/flag_gems/runtime/backend/_mthreads/ops/batch_norm.py: 0%

283 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2from typing import Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import tl_extra_shim 

10 

# Module logger named after this op file (".../ops.batch_norm").
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
# Backend-provided reciprocal-square-root primitive used inside the Triton kernels.
rsqrt = tl_extra_shim.rsqrt

15 

16 

17def _make_3d_for_bn(input: torch.Tensor) -> torch.Tensor: 

18 if input.ndim == 2: 

19 return input.unsqueeze(-1) 

20 if input.ndim >= 4: 

21 return input.flatten(2, -1) 

22 return input 

23 

24 

25def _block_size(numel: int) -> int: 

26 if numel >= 524288: 

27 return 512 

28 if numel >= 1024: 

29 return 256 

30 if numel >= 256: 

31 return 128 

32 return 64 

33 

34 

35def _num_warps(block: int) -> int: 

36 if block >= 512: 

37 return 8 

38 if block >= 256: 

39 return 4 

40 if block >= 128: 

41 return 4 

42 return 2 

43 

44 

# Tile size used by the second-stage partial-sum reduction kernels.
_REDUCE_BLOCK = 256
# At or below this many (batch * spatial) elements, training uses the
# single-program-per-feature fused kernel instead of the staged reduction.
_FALLBACK_ELEMENTS = 4096
# At or below this many elements, defer entirely to the native aten op.
_NATIVE_SWITCH_ELEMENTS = 32768
# Cache of placeholder (running_mean, running_var) tensor pairs, keyed by
# (device, dtype, feat_dim); used when the caller supplies no running stats.
_NATIVE_CACHE = {}

49 

50 

def _get_temp_stats(device, dtype, feat_dim):
    """Return cached placeholder (running_mean, running_var) tensors.

    The native aten batch-norm op requires tensor arguments for the
    running statistics even when the caller does not track them; this
    creates a zeros/ones pair once per (device, dtype, feat_dim) key and
    reuses it on later calls.

    NOTE(review): aten may update these buffers in place when training
    is enabled, so cached values can drift from zeros/ones between
    calls; the training-mode output does not read them, but confirm
    callers never rely on their contents.
    """
    key = (device, dtype, feat_dim)
    cached = _NATIVE_CACHE.get(key)
    if cached is None:
        # No extra size check needed on a hit: the key already pins
        # feat_dim, so a cached pair always has the right numel.
        rm = torch.zeros((feat_dim,), device=device, dtype=dtype)
        rv = torch.ones((feat_dim,), device=device, dtype=dtype)
        cached = (rm, rv)
        _NATIVE_CACHE[key] = cached
    return cached

59 

60 

@triton.jit
def _bn_forward_stats_stage1(
    input_ptr,
    partial_sum_ptr,
    partial_sq_ptr,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    num_blocks,
    BLOCK: tl.constexpr,
):
    # Stage 1 of the two-stage mean/variance reduction: each program
    # handles one (feature, tile) pair and writes the tile's sum of x and
    # sum of x^2 into (feat_dim, num_blocks) scratch buffers.
    feat = tl.program_id(0)
    block_id = tl.program_id(1)

    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    mask = offset < total

    # Decompose the flat (batch * spatial) offset into coordinates.
    batch_idx = offset // spatial_dim
    spatial_idx = offset - batch_idx * spatial_dim

    ptrs = (
        input_ptr
        + feat * input_feat_stride
        + batch_idx * input_batch_stride
        + spatial_idx * input_spatial_stride
    )
    # Accumulate in fp32 regardless of the input dtype for accuracy;
    # masked lanes contribute 0 to both sums.
    values = tl.load(ptrs, mask=mask, other=0.0).to(tl.float32)

    tl.store(partial_sum_ptr + feat * num_blocks + block_id, tl.sum(values, axis=0))
    tl.store(
        partial_sq_ptr + feat * num_blocks + block_id,
        tl.sum(values * values, axis=0),
    )

97 

98 

@triton.jit
def _bn_reduce_partial_kernel(
    partial_sum_ptr,
    partial_sq_ptr,
    sum_ptr,
    sum_sq_ptr,
    num_blocks,
    BLOCK: tl.constexpr,
):
    # Stage 2: fold the per-tile partial sums from stage 1 into one
    # scalar per feature.  Multiple programs may cover one feature's
    # partials, hence the atomic accumulation; sum_ptr/sum_sq_ptr must
    # be zero-initialized by the caller.
    feat = tl.program_id(0)
    block_id = tl.program_id(1)

    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    mask = offset < num_blocks

    partial_sum = tl.load(
        partial_sum_ptr + feat * num_blocks + offset, mask=mask, other=0.0
    )
    partial_sq = tl.load(
        partial_sq_ptr + feat * num_blocks + offset, mask=mask, other=0.0
    )

    tl.atomic_add(sum_ptr + feat, tl.sum(partial_sum, axis=0))
    tl.atomic_add(sum_sq_ptr + feat, tl.sum(partial_sq, axis=0))

123 

124 

@triton.jit
def _bn_fused_train_kernel(
    input_ptr,
    weight_ptr,
    bias_ptr,
    output_ptr,
    mean_ptr,
    inv_std_ptr,
    running_mean_ptr,
    running_var_ptr,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    update_running: tl.constexpr,
    BLOCK: tl.constexpr,
):
    # Fused single-pass training kernel for small inputs: one program per
    # feature computes the batch statistics, optionally updates the
    # running stats, then normalizes all elements of that feature.
    feat = tl.program_id(0)
    offsets = tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    num_tiles = tl.cdiv(total, BLOCK)

    # fp32 accumulators for sum(x) and sum(x^2).
    sum_val = tl.zeros((), dtype=tl.float32)
    sum_sq_val = tl.zeros((), dtype=tl.float32)

    # Pass 1: reduce over all (batch, spatial) elements of this feature.
    for tile in range(0, num_tiles):
        idx = tile * BLOCK + offsets
        mask = idx < total
        batch_idx = idx // spatial_dim
        spatial_idx = idx - batch_idx * spatial_dim
        ptrs = (
            input_ptr
            + feat * input_feat_stride
            + batch_idx * input_batch_stride
            + spatial_idx * input_spatial_stride
        )
        vals = tl.load(ptrs, mask=mask, other=0.0).to(tl.float32)
        sum_val += tl.sum(vals, axis=0)
        sum_sq_val += tl.sum(vals * vals, axis=0)

    total_f = tl.full((), total, tl.float32)
    mean = sum_val / total_f
    # E[x^2] - E[x]^2, clamped at 0 to guard against fp round-off.
    var = tl.maximum(sum_sq_val / total_f - mean * mean, 0.0)
    inv_std = rsqrt(var + eps)

    # Save batch statistics for the backward pass.
    tl.store(mean_ptr + feat, mean)
    tl.store(inv_std_ptr + feat, inv_std)

    if update_running:
        # Exponential moving average update; the running variance uses
        # the unbiased (n / (n - 1)) estimate, per batch-norm convention.
        running_mean = tl.load(running_mean_ptr + feat)
        running_var = tl.load(running_var_ptr + feat)
        unbiased_var = var * total_f / tl.maximum(total_f - 1, 1.0)
        tl.store(
            running_mean_ptr + feat, (1 - momentum) * running_mean + momentum * mean
        )
        tl.store(
            running_var_ptr + feat,
            (1 - momentum) * running_var + momentum * unbiased_var,
        )

    # Null weight/bias pointers (affine disabled) fall back to identity.
    weight = tl.load(weight_ptr + feat).to(tl.float32) if weight_ptr else 1.0
    bias = tl.load(bias_ptr + feat).to(tl.float32) if bias_ptr else 0.0

    # Pass 2: normalize and apply the affine transform.
    for tile in range(0, num_tiles):
        idx = tile * BLOCK + offsets
        mask = idx < total
        batch_idx = idx // spatial_dim
        spatial_idx = idx - batch_idx * spatial_dim
        input_ptrs = (
            input_ptr
            + feat * input_feat_stride
            + batch_idx * input_batch_stride
            + spatial_idx * input_spatial_stride
        )
        output_ptrs = (
            output_ptr
            + feat * output_feat_stride
            + batch_idx * output_batch_stride
            + spatial_idx * output_spatial_stride
        )
        vals = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
        out = (vals - mean) * inv_std * weight + bias
        tl.store(output_ptrs, out, mask=mask)

214 

215 

@triton.jit
def _bn_forward_finalize_kernel(
    sum_ptr,
    sum_sq_ptr,
    mean_ptr,
    inv_std_ptr,
    running_mean_ptr,
    running_var_ptr,
    total_elems,
    momentum,
    eps,
    update_running: tl.constexpr,
):
    # Turn the per-feature sum / sum-of-squares into mean and inverse
    # standard deviation, and optionally fold them into the running
    # statistics.  One program per feature.
    feat = tl.program_id(0)
    sum_val = tl.load(sum_ptr + feat)
    sum_sq_val = tl.load(sum_sq_ptr + feat)

    total = tl.full((), total_elems, tl.float32)
    mean = sum_val / total
    # E[x^2] - E[x]^2, clamped at 0 to guard against fp round-off.
    var = tl.maximum(sum_sq_val / total - mean * mean, 0.0)
    inv_std = rsqrt(var + eps)

    tl.store(mean_ptr + feat, mean)
    tl.store(inv_std_ptr + feat, inv_std)

    if update_running:
        # Guard against null running-stat pointers; the moving average
        # uses the unbiased (n / (n - 1)) variance estimate.
        if running_mean_ptr and running_var_ptr:
            running_mean = tl.load(running_mean_ptr + feat)
            running_var = tl.load(running_var_ptr + feat)
            unbiased_var = var * total / tl.maximum(total - 1.0, 1.0)
            tl.store(
                running_mean_ptr + feat,
                (1 - momentum) * running_mean + momentum * mean,
            )
            tl.store(
                running_var_ptr + feat,
                (1 - momentum) * running_var + momentum * unbiased_var,
            )

254 

255 

@triton.jit
def _bn_forward_apply_kernel(
    input_ptr,
    weight_ptr,
    bias_ptr,
    mean_ptr,
    inv_std_ptr,
    output_ptr,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    BLOCK: tl.constexpr,
):
    # Elementwise normalization: out = (x - mean) * inv_std * weight + bias.
    # Grid is (feat_dim, num_blocks); each program covers one tile of one
    # feature's flattened (batch * spatial) elements.
    feat = tl.program_id(0)
    block_id = tl.program_id(1)

    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    mask = offset < total

    # Decompose the flat offset into (batch, spatial) coordinates.
    batch_idx = offset // spatial_dim
    spatial_idx = offset - batch_idx * spatial_dim

    mean = tl.load(mean_ptr + feat).to(tl.float32)
    inv_std = tl.load(inv_std_ptr + feat).to(tl.float32)

    # Null weight/bias pointers (affine disabled) fall back to identity.
    weight = tl.load(weight_ptr + feat).to(tl.float32) if weight_ptr else 1.0
    bias = tl.load(bias_ptr + feat).to(tl.float32) if bias_ptr else 0.0

    input_ptrs = (
        input_ptr
        + feat * input_feat_stride
        + batch_idx * input_batch_stride
        + spatial_idx * input_spatial_stride
    )
    output_ptrs = (
        output_ptr
        + feat * output_feat_stride
        + batch_idx * output_batch_stride
        + spatial_idx * output_spatial_stride
    )

    values = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
    output = (values - mean) * inv_std * weight + bias
    tl.store(output_ptrs, output, mask=mask)

306 

307 

@triton.jit
def _bn_backward_reduce_kernel(
    output_grad_ptr,
    input_ptr,
    mean_ptr,
    inv_std_ptr,
    partial_sum_ptr,
    partial_sum_xhat_ptr,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    num_blocks,
    BLOCK: tl.constexpr,
):
    # Backward stage 1: per (feature, tile), compute partial sums of dy
    # and dy * x_hat (x_hat = normalized input), written into
    # (feat_dim, num_blocks) scratch buffers.
    feat = tl.program_id(0)
    block_id = tl.program_id(1)

    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    mask = offset < total

    # Decompose the flat offset into (batch, spatial) coordinates.
    batch_idx = offset // spatial_dim
    spatial_idx = offset - batch_idx * spatial_dim

    mean = tl.load(mean_ptr + feat).to(tl.float32)
    inv_std = tl.load(inv_std_ptr + feat).to(tl.float32)

    grad_ptrs = (
        output_grad_ptr
        + feat * output_grad_feat_stride
        + batch_idx * output_grad_batch_stride
        + spatial_idx * output_grad_spatial_stride
    )
    input_ptrs = (
        input_ptr
        + feat * input_feat_stride
        + batch_idx * input_batch_stride
        + spatial_idx * input_spatial_stride
    )

    dy = tl.load(grad_ptrs, mask=mask, other=0.0).to(tl.float32)
    x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_hat = (x - mean) * inv_std

    tl.store(partial_sum_ptr + feat * num_blocks + block_id, tl.sum(dy, axis=0))
    tl.store(
        partial_sum_xhat_ptr + feat * num_blocks + block_id,
        tl.sum(dy * x_hat, axis=0),
    )

362 

363 

@triton.jit
def _bn_backward_reduce_partial_kernel(
    partial_sum_ptr,
    partial_sum_xhat_ptr,
    sum_dy_ptr,
    sum_dy_xhat_ptr,
    num_blocks,
    BLOCK: tl.constexpr,
):
    # Backward stage 2: fold per-tile partials into one sum(dy) and one
    # sum(dy * x_hat) per feature.  Atomic because several programs may
    # cover one feature; output buffers must be zero-initialized.
    feat = tl.program_id(0)
    block_id = tl.program_id(1)

    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    mask = offset < num_blocks

    partial_sum = tl.load(
        partial_sum_ptr + feat * num_blocks + offset, mask=mask, other=0.0
    )
    partial_sum_xhat = tl.load(
        partial_sum_xhat_ptr + feat * num_blocks + offset, mask=mask, other=0.0
    )

    tl.atomic_add(sum_dy_ptr + feat, tl.sum(partial_sum, axis=0))
    tl.atomic_add(sum_dy_xhat_ptr + feat, tl.sum(partial_sum_xhat, axis=0))

388 

389 

@triton.jit
def _bn_backward_input_kernel(
    output_grad_ptr,
    input_ptr,
    mean_ptr,
    inv_std_ptr,
    weight_ptr,
    sum_dy_ptr,
    sum_dy_xhat_ptr,
    input_grad_ptr,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    BLOCK: tl.constexpr,
):
    # Input gradient of batch norm (training form):
    #   dx = (dy - mean(dy) - x_hat * mean(dy * x_hat)) * inv_std * weight
    # using the per-feature reductions produced by the previous stages.
    feat = tl.program_id(0)
    block_id = tl.program_id(1)

    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    mask = offset < total

    # Decompose the flat offset into (batch, spatial) coordinates.
    batch_idx = offset // spatial_dim
    spatial_idx = offset - batch_idx * spatial_dim

    mean = tl.load(mean_ptr + feat).to(tl.float32)
    inv_std = tl.load(inv_std_ptr + feat).to(tl.float32)
    sum_dy = tl.load(sum_dy_ptr + feat)
    sum_dy_xhat = tl.load(sum_dy_xhat_ptr + feat)
    count = tl.full((), total, tl.float32)

    # Null weight pointer (affine disabled) falls back to identity scale.
    weight = tl.load(weight_ptr + feat).to(tl.float32) if weight_ptr else 1.0

    grad_ptrs = (
        output_grad_ptr
        + feat * output_grad_feat_stride
        + batch_idx * output_grad_batch_stride
        + spatial_idx * output_grad_spatial_stride
    )
    input_ptrs = (
        input_ptr
        + feat * input_feat_stride
        + batch_idx * input_batch_stride
        + spatial_idx * input_spatial_stride
    )
    input_grad_ptrs = (
        input_grad_ptr
        + feat * input_grad_feat_stride
        + batch_idx * input_grad_batch_stride
        + spatial_idx * input_grad_spatial_stride
    )

    dy = tl.load(grad_ptrs, mask=mask, other=0.0).to(tl.float32)
    x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_hat = (x - mean) * inv_std

    term = (dy - sum_dy / count - x_hat * sum_dy_xhat / count) * inv_std * weight
    tl.store(input_grad_ptrs, term, mask=mask)

456 

457 

@triton.jit
def _bn_backward_param_kernel(
    sum_dy_ptr,
    sum_dy_xhat_ptr,
    weight_grad_ptr,
    bias_grad_ptr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
):
    # Parameter gradients straight from the reductions:
    #   dweight = sum(dy * x_hat), dbias = sum(dy).
    # The constexpr masks compile away the stores that are not requested.
    feat = tl.program_id(0)
    if weight_grad_mask:
        tl.store(weight_grad_ptr + feat, tl.load(sum_dy_xhat_ptr + feat))
    if bias_grad_mask:
        tl.store(bias_grad_ptr + feat, tl.load(sum_dy_ptr + feat))

472 

473 

def _get_launch_config(batch_dim: int, spatial_dim: int) -> Tuple[int, int, int]:
    """Return (block, num_blocks, num_warps) for reducing over
    ``batch_dim * spatial_dim`` elements per feature."""
    numel = batch_dim * spatial_dim
    chosen_block = _block_size(numel)
    grid_blocks = triton.cdiv(numel, chosen_block)
    return chosen_block, grid_blocks, _num_warps(chosen_block)

479 

480 

def batch_norm(
    input: torch.Tensor,
    weight=None,
    bias=None,
    running_mean=None,
    running_var=None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-05,
):
    """Batch normalization forward for the mthreads backend.

    Returns ``(output, save_mean, save_invstd)`` with ``output`` shaped
    like *input*.  Dispatches among three paths by problem size:
    small inputs go to the native aten op, mid-size training problems to
    a fused single-pass kernel, and everything else to a staged
    reduce-then-apply pipeline.

    NOTE(review): ``save_mean``/``save_invstd`` are allocated in
    input.dtype while the kernels compute in fp32 (cast on store) —
    confirm downstream consumers accept that precision.
    """
    logger.debug("GEMS_MTHREADS BATCHNORM FORWARD")
    input_3d = _make_3d_for_bn(input)
    batch_dim, feat_dim, spatial_dim = input_3d.shape
    total = batch_dim * spatial_dim

    # Small problems: the native implementation wins; substitute cached
    # placeholder running stats if the caller tracks none.
    if total <= _NATIVE_SWITCH_ELEMENTS:
        rm = running_mean
        rv = running_var
        if rm is None or rv is None:
            rm, rv = _get_temp_stats(input.device, input.dtype, feat_dim)
        with torch_device_fn.device(input.device):
            return torch.ops.aten._native_batch_norm_legit.default(
                input, weight, bias, rm, rv, training, momentum, eps
            )
    output = torch.empty_like(input_3d)
    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty_like(mean)

    # Batch statistics are needed in training, or whenever running stats
    # are missing; running stats are only updated when training AND both
    # buffers exist.
    need_stats = training or running_mean is None or running_var is None
    update_running = training and running_mean is not None and running_var is not None

    # Fused path condition repeats need_stats: small AND stats required.
    small_training = total <= _FALLBACK_ELEMENTS and (
        training or running_mean is None or running_var is None
    )

    if small_training:
        block = _block_size(total)
        num_warps = _num_warps(block)
        with torch_device_fn.device(input.device):
            _bn_fused_train_kernel[(feat_dim,)](
                input_3d,
                weight,
                bias,
                output,
                mean,
                inv_std,
                # mean/inv_std double as harmless dummies when running
                # stats are absent (update_running is False then).
                running_mean if running_mean is not None else mean,
                running_var if running_var is not None else inv_std,
                batch_dim,
                spatial_dim,
                *input_3d.stride(),
                *output.stride(),
                momentum,
                eps,
                update_running=update_running,
                BLOCK=block,
                num_warps=num_warps,
            )
        return output.view_as(input), mean, inv_std

    block, num_blocks, num_warps = _get_launch_config(batch_dim, spatial_dim)

    with torch_device_fn.device(input.device):
        if need_stats:
            # Stage 1: per-tile partial sums of x and x^2 per feature.
            partial_shape = (feat_dim, num_blocks)
            partial_sum = torch.empty(
                partial_shape, device=input.device, dtype=torch.float32
            )
            partial_sq = torch.empty_like(partial_sum)

            _bn_forward_stats_stage1[(feat_dim, num_blocks)](
                input_3d,
                partial_sum,
                partial_sq,
                batch_dim,
                spatial_dim,
                *input_3d.stride(),
                num_blocks,
                BLOCK=block,
                num_warps=num_warps,
            )

            # Stage 2: fold partials down to one scalar per feature.
            if num_blocks == 1:
                sum_buf = partial_sum[:, 0].contiguous()
                sum_sq_buf = partial_sq[:, 0].contiguous()
            else:
                # Zero-init is required: the reduce kernel atomically adds.
                sum_buf = torch.zeros(
                    (feat_dim,), device=input.device, dtype=torch.float32
                )
                sum_sq_buf = torch.zeros_like(sum_buf)
                reduce_blocks = triton.cdiv(num_blocks, _REDUCE_BLOCK)
                _bn_reduce_partial_kernel[(feat_dim, reduce_blocks)](
                    partial_sum,
                    partial_sq,
                    sum_buf,
                    sum_sq_buf,
                    num_blocks,
                    BLOCK=_REDUCE_BLOCK,
                    num_warps=_num_warps(_REDUCE_BLOCK),
                )

            # Stage 3: mean / inv_std (+ optional running-stat update).
            _bn_forward_finalize_kernel[(feat_dim,)](
                sum_buf,
                sum_sq_buf,
                mean,
                inv_std,
                running_mean,
                running_var,
                total,
                momentum,
                eps,
                update_running=update_running,
                num_warps=1,
            )
        else:
            # Eval mode with stats: normalize against the running values.
            if running_mean is None or running_var is None:
                raise RuntimeError(
                    "running_mean and running_var are required in eval mode"
                )
            mean.copy_(running_mean)
            inv_std.copy_((running_var + eps).rsqrt())

        # Final elementwise normalization + affine transform.
        _bn_forward_apply_kernel[(feat_dim, num_blocks)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            BLOCK=block,
            num_warps=num_warps,
        )

    return output.view_as(input), mean, inv_std

619 

620 

def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train: bool = False,
    eps: float = 1e-05,
    output_mask=None,
):
    """Batch normalization backward for the mthreads backend.

    Returns ``(input_grad, weight_grad, bias_grad)``; each entry is None
    when the corresponding ``output_mask`` flag is False.  Statistics
    are always taken from ``save_mean``/``save_invstd``.

    NOTE(review): ``running_mean``/``running_var``/``train``/``eps`` are
    accepted for signature compatibility but unused here — an eval-mode
    backward would need the running stats; confirm callers only hit this
    for training graphs.
    """
    logger.debug("GEMS_MTHREADS BATCHNORM BACKWARD")

    # Fix: the declared default output_mask=None previously crashed on
    # output_mask[0]; treat None as "compute all three gradients".
    if output_mask is None:
        output_mask = (True, True, True)

    input_3d = _make_3d_for_bn(input)
    output_grad_3d = _make_3d_for_bn(grad_out)
    batch_dim, feat_dim, spatial_dim = input_3d.shape

    if output_mask[0]:
        input_grad = torch.empty_like(input_3d)
    else:
        input_grad = None
    if output_mask[1]:
        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        weight_grad = None
    if output_mask[2]:
        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        bias_grad = None

    block, num_blocks, num_warps = _get_launch_config(batch_dim, spatial_dim)

    with torch_device_fn.device(input.device):
        # Stage 1: per-tile partial sums of dy and dy * x_hat per feature.
        partial_shape = (feat_dim, num_blocks)
        partial_sum = torch.empty(
            partial_shape, device=input.device, dtype=torch.float32
        )
        partial_sum_xhat = torch.empty_like(partial_sum)

        _bn_backward_reduce_kernel[(feat_dim, num_blocks)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            partial_sum,
            partial_sum_xhat,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            num_blocks,
            BLOCK=block,
            num_warps=num_warps,
        )

        # Stage 2: fold partials to one sum(dy) / sum(dy * x_hat) per
        # feature; zero-init is required because the kernel atomic-adds.
        if num_blocks == 1:
            sum_dy = partial_sum[:, 0].contiguous()
            sum_dy_xhat = partial_sum_xhat[:, 0].contiguous()
        else:
            sum_dy = torch.zeros((feat_dim,), device=input.device, dtype=torch.float32)
            sum_dy_xhat = torch.zeros_like(sum_dy)
            reduce_blocks = triton.cdiv(num_blocks, _REDUCE_BLOCK)
            _bn_backward_reduce_partial_kernel[(feat_dim, reduce_blocks)](
                partial_sum,
                partial_sum_xhat,
                sum_dy,
                sum_dy_xhat,
                num_blocks,
                BLOCK=_REDUCE_BLOCK,
                num_warps=_num_warps(_REDUCE_BLOCK),
            )

        # Stage 3a: elementwise input gradient, if requested.
        if output_mask[0]:
            _bn_backward_input_kernel[(feat_dim, num_blocks)](
                output_grad_3d,
                input_3d,
                save_mean,
                save_invstd,
                weight,
                sum_dy,
                sum_dy_xhat,
                input_grad,
                batch_dim,
                spatial_dim,
                *output_grad_3d.stride(),
                *input_3d.stride(),
                *input_grad.stride(),
                BLOCK=block,
                num_warps=num_warps,
            )

        # Stage 3b: parameter gradients straight from the reductions
        # (sum_dy doubles as a dummy pointer for disabled outputs; the
        # constexpr masks keep the kernel from storing through it).
        if output_mask[1] or output_mask[2]:
            _bn_backward_param_kernel[(feat_dim,)](
                sum_dy,
                sum_dy_xhat,
                weight_grad if weight_grad is not None else sum_dy,
                bias_grad if bias_grad is not None else sum_dy,
                weight_grad_mask=output_mask[1],
                bias_grad_mask=output_mask[2],
                num_warps=1,
            )

    return (
        input_grad.view_as(input) if input_grad is not None else None,
        weight_grad,
        bias_grad,
    )