Coverage for src/flag_gems/runtime/backend/_mthreads/ops/batch_norm.py: 0%
283 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2from typing import Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import tl_extra_shim
# Module logger scoped under the mthreads backend namespace, e.g.
# "flag_gems.runtime.backend._mthreads.ops.batch_norm".
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)
# Backend-provided reciprocal square root, callable from inside Triton kernels.
rsqrt = tl_extra_shim.rsqrt
17def _make_3d_for_bn(input: torch.Tensor) -> torch.Tensor:
18 if input.ndim == 2:
19 return input.unsqueeze(-1)
20 if input.ndim >= 4:
21 return input.flatten(2, -1)
22 return input
25def _block_size(numel: int) -> int:
26 if numel >= 524288:
27 return 512
28 if numel >= 1024:
29 return 256
30 if numel >= 256:
31 return 128
32 return 64
35def _num_warps(block: int) -> int:
36 if block >= 512:
37 return 8
38 if block >= 256:
39 return 4
40 if block >= 128:
41 return 4
42 return 2
# Block size used by the second-stage partial-sum reduction kernels.
_REDUCE_BLOCK = 256
# At or below this per-feature element count (in training/stat-computing
# mode), use the single fused Triton kernel instead of the staged path.
_FALLBACK_ELEMENTS = 4096
# At or below this per-feature element count, defer entirely to the native
# aten batch-norm implementation.
_NATIVE_SWITCH_ELEMENTS = 32768
# Cache of placeholder (running_mean, running_var) tensors, keyed by
# (device, dtype, feat_dim); populated by _get_temp_stats.
_NATIVE_CACHE = {}
def _get_temp_stats(device, dtype, feat_dim):
    """Return cached placeholder (running_mean, running_var) tensors.

    The placeholders are a zeros tensor and a ones tensor of length
    *feat_dim*, memoized in _NATIVE_CACHE per (device, dtype, feat_dim).
    """
    key = (device, dtype, feat_dim)
    entry = _NATIVE_CACHE.get(key)
    # The numel re-check is defensive; feat_dim is already part of the key.
    if entry is not None and entry[0].numel() == feat_dim:
        return entry
    entry = (
        torch.zeros((feat_dim,), device=device, dtype=dtype),
        torch.ones((feat_dim,), device=device, dtype=dtype),
    )
    _NATIVE_CACHE[key] = entry
    return entry
# Stage 1 of the two-stage mean/variance reduction: each program instance
# reduces one (feature, block) tile of the input to a partial sum and a
# partial sum of squares, written at index [feat, block_id] of the two
# partial buffers (laid out as feat-major, num_blocks wide).
@triton.jit
def _bn_forward_stats_stage1(
    input_ptr,
    partial_sum_ptr,
    partial_sq_ptr,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    num_blocks,
    BLOCK: tl.constexpr,
):
    feat = tl.program_id(0)
    block_id = tl.program_id(1)
    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    mask = offset < total
    # Decompose the flat per-feature offset into (batch, spatial) coords.
    batch_idx = offset // spatial_dim
    spatial_idx = offset - batch_idx * spatial_dim
    ptrs = (
        input_ptr
        + feat * input_feat_stride
        + batch_idx * input_batch_stride
        + spatial_idx * input_spatial_stride
    )
    # Accumulate in float32 regardless of the input dtype; masked lanes
    # contribute 0 to both sums.
    values = tl.load(ptrs, mask=mask, other=0.0).to(tl.float32)
    tl.store(partial_sum_ptr + feat * num_blocks + block_id, tl.sum(values, axis=0))
    tl.store(
        partial_sq_ptr + feat * num_blocks + block_id,
        tl.sum(values * values, axis=0),
    )
# Stage 2 of the forward reduction: collapse each feature's row of partial
# sums into scalar totals. Atomic adds are used because several reduce
# blocks may run per feature; sum_ptr/sum_sq_ptr must be zero-initialized.
@triton.jit
def _bn_reduce_partial_kernel(
    partial_sum_ptr,
    partial_sq_ptr,
    sum_ptr,
    sum_sq_ptr,
    num_blocks,
    BLOCK: tl.constexpr,
):
    feat = tl.program_id(0)
    block_id = tl.program_id(1)
    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    mask = offset < num_blocks
    partial_sum = tl.load(
        partial_sum_ptr + feat * num_blocks + offset, mask=mask, other=0.0
    )
    partial_sq = tl.load(
        partial_sq_ptr + feat * num_blocks + offset, mask=mask, other=0.0
    )
    tl.atomic_add(sum_ptr + feat, tl.sum(partial_sum, axis=0))
    tl.atomic_add(sum_sq_ptr + feat, tl.sum(partial_sq, axis=0))
# Fused training kernel for small problems: one program instance per
# feature. Pass 1 accumulates mean/variance over every (batch, spatial)
# element in BLOCK-sized tiles; the stats are stored (and running stats
# optionally updated), then pass 2 re-reads the input and writes the
# normalized, affine-transformed output.
@triton.jit
def _bn_fused_train_kernel(
    input_ptr,
    weight_ptr,
    bias_ptr,
    output_ptr,
    mean_ptr,
    inv_std_ptr,
    running_mean_ptr,
    running_var_ptr,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    update_running: tl.constexpr,
    BLOCK: tl.constexpr,
):
    feat = tl.program_id(0)
    offsets = tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    num_tiles = tl.cdiv(total, BLOCK)
    # Pass 1: accumulate sum and sum of squares in float32.
    sum_val = tl.zeros((), dtype=tl.float32)
    sum_sq_val = tl.zeros((), dtype=tl.float32)
    for tile in range(0, num_tiles):
        idx = tile * BLOCK + offsets
        mask = idx < total
        batch_idx = idx // spatial_dim
        spatial_idx = idx - batch_idx * spatial_dim
        ptrs = (
            input_ptr
            + feat * input_feat_stride
            + batch_idx * input_batch_stride
            + spatial_idx * input_spatial_stride
        )
        vals = tl.load(ptrs, mask=mask, other=0.0).to(tl.float32)
        sum_val += tl.sum(vals, axis=0)
        sum_sq_val += tl.sum(vals * vals, axis=0)
    total_f = tl.full((), total, tl.float32)
    mean = sum_val / total_f
    # E[x^2] - E[x]^2, clamped at 0 to guard against rounding error.
    var = tl.maximum(sum_sq_val / total_f - mean * mean, 0.0)
    inv_std = rsqrt(var + eps)
    tl.store(mean_ptr + feat, mean)
    tl.store(inv_std_ptr + feat, inv_std)
    if update_running:
        running_mean = tl.load(running_mean_ptr + feat)
        running_var = tl.load(running_var_ptr + feat)
        # Running variance uses the unbiased (n-1) estimator; the max
        # guards the degenerate total == 1 case.
        unbiased_var = var * total_f / tl.maximum(total_f - 1, 1.0)
        tl.store(
            running_mean_ptr + feat, (1 - momentum) * running_mean + momentum * mean
        )
        tl.store(
            running_var_ptr + feat,
            (1 - momentum) * running_var + momentum * unbiased_var,
        )
    # Affine parameters default to the identity when pointers are null.
    weight = tl.load(weight_ptr + feat).to(tl.float32) if weight_ptr else 1.0
    bias = tl.load(bias_ptr + feat).to(tl.float32) if bias_ptr else 0.0
    # Pass 2: normalize each tile and write the output.
    for tile in range(0, num_tiles):
        idx = tile * BLOCK + offsets
        mask = idx < total
        batch_idx = idx // spatial_dim
        spatial_idx = idx - batch_idx * spatial_dim
        input_ptrs = (
            input_ptr
            + feat * input_feat_stride
            + batch_idx * input_batch_stride
            + spatial_idx * input_spatial_stride
        )
        output_ptrs = (
            output_ptr
            + feat * output_feat_stride
            + batch_idx * output_batch_stride
            + spatial_idx * output_spatial_stride
        )
        vals = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
        out = (vals - mean) * inv_std * weight + bias
        tl.store(output_ptrs, out, mask=mask)
# Finalize per-feature statistics: convert the reduced sum / sum-of-squares
# into mean and inverse standard deviation, and optionally fold them into
# the running statistics with the usual EMA update.
@triton.jit
def _bn_forward_finalize_kernel(
    sum_ptr,
    sum_sq_ptr,
    mean_ptr,
    inv_std_ptr,
    running_mean_ptr,
    running_var_ptr,
    total_elems,
    momentum,
    eps,
    update_running: tl.constexpr,
):
    feat = tl.program_id(0)
    sum_val = tl.load(sum_ptr + feat)
    sum_sq_val = tl.load(sum_sq_ptr + feat)
    total = tl.full((), total_elems, tl.float32)
    mean = sum_val / total
    # E[x^2] - E[x]^2, clamped at 0 to guard against rounding error.
    var = tl.maximum(sum_sq_val / total - mean * mean, 0.0)
    inv_std = rsqrt(var + eps)
    tl.store(mean_ptr + feat, mean)
    tl.store(inv_std_ptr + feat, inv_std)
    if update_running:
        # Guard against null running-stat pointers.
        if running_mean_ptr and running_var_ptr:
            running_mean = tl.load(running_mean_ptr + feat)
            running_var = tl.load(running_var_ptr + feat)
            # Running variance uses the unbiased (n-1) estimator.
            unbiased_var = var * total / tl.maximum(total - 1.0, 1.0)
            tl.store(
                running_mean_ptr + feat,
                (1 - momentum) * running_mean + momentum * mean,
            )
            tl.store(
                running_var_ptr + feat,
                (1 - momentum) * running_var + momentum * unbiased_var,
            )
# Elementwise normalization: given precomputed per-feature mean/inv_std,
# each program instance normalizes one (feature, block) tile of the input
# and applies the optional affine transform.
@triton.jit
def _bn_forward_apply_kernel(
    input_ptr,
    weight_ptr,
    bias_ptr,
    mean_ptr,
    inv_std_ptr,
    output_ptr,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    BLOCK: tl.constexpr,
):
    feat = tl.program_id(0)
    block_id = tl.program_id(1)
    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    mask = offset < total
    # Decompose the flat per-feature offset into (batch, spatial) coords.
    batch_idx = offset // spatial_dim
    spatial_idx = offset - batch_idx * spatial_dim
    mean = tl.load(mean_ptr + feat).to(tl.float32)
    inv_std = tl.load(inv_std_ptr + feat).to(tl.float32)
    # Affine parameters default to the identity when pointers are null.
    weight = tl.load(weight_ptr + feat).to(tl.float32) if weight_ptr else 1.0
    bias = tl.load(bias_ptr + feat).to(tl.float32) if bias_ptr else 0.0
    input_ptrs = (
        input_ptr
        + feat * input_feat_stride
        + batch_idx * input_batch_stride
        + spatial_idx * input_spatial_stride
    )
    output_ptrs = (
        output_ptr
        + feat * output_feat_stride
        + batch_idx * output_batch_stride
        + spatial_idx * output_spatial_stride
    )
    values = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
    output = (values - mean) * inv_std * weight + bias
    tl.store(output_ptrs, output, mask=mask)
# Backward stage 1: per (feature, block) tile, compute partial sums of
# dy and of dy * x_hat (the normalized input), written at [feat, block_id]
# in the partial buffers. These feed both grad_input and the parameter
# gradients.
@triton.jit
def _bn_backward_reduce_kernel(
    output_grad_ptr,
    input_ptr,
    mean_ptr,
    inv_std_ptr,
    partial_sum_ptr,
    partial_sum_xhat_ptr,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    num_blocks,
    BLOCK: tl.constexpr,
):
    feat = tl.program_id(0)
    block_id = tl.program_id(1)
    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    mask = offset < total
    # Decompose the flat per-feature offset into (batch, spatial) coords.
    batch_idx = offset // spatial_dim
    spatial_idx = offset - batch_idx * spatial_dim
    mean = tl.load(mean_ptr + feat).to(tl.float32)
    inv_std = tl.load(inv_std_ptr + feat).to(tl.float32)
    grad_ptrs = (
        output_grad_ptr
        + feat * output_grad_feat_stride
        + batch_idx * output_grad_batch_stride
        + spatial_idx * output_grad_spatial_stride
    )
    input_ptrs = (
        input_ptr
        + feat * input_feat_stride
        + batch_idx * input_batch_stride
        + spatial_idx * input_spatial_stride
    )
    dy = tl.load(grad_ptrs, mask=mask, other=0.0).to(tl.float32)
    x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_hat = (x - mean) * inv_std
    tl.store(partial_sum_ptr + feat * num_blocks + block_id, tl.sum(dy, axis=0))
    tl.store(
        partial_sum_xhat_ptr + feat * num_blocks + block_id,
        tl.sum(dy * x_hat, axis=0),
    )
# Backward stage 2: collapse each feature's row of partial sums into
# scalar totals via atomic adds (several reduce blocks may run per
# feature; sum_dy_ptr/sum_dy_xhat_ptr must be zero-initialized).
@triton.jit
def _bn_backward_reduce_partial_kernel(
    partial_sum_ptr,
    partial_sum_xhat_ptr,
    sum_dy_ptr,
    sum_dy_xhat_ptr,
    num_blocks,
    BLOCK: tl.constexpr,
):
    feat = tl.program_id(0)
    block_id = tl.program_id(1)
    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    mask = offset < num_blocks
    partial_sum = tl.load(
        partial_sum_ptr + feat * num_blocks + offset, mask=mask, other=0.0
    )
    partial_sum_xhat = tl.load(
        partial_sum_xhat_ptr + feat * num_blocks + offset, mask=mask, other=0.0
    )
    tl.atomic_add(sum_dy_ptr + feat, tl.sum(partial_sum, axis=0))
    tl.atomic_add(sum_dy_xhat_ptr + feat, tl.sum(partial_sum_xhat, axis=0))
# Backward input-gradient kernel: applies the standard batch-norm grad
# formula dx = (dy - mean(dy) - x_hat * mean(dy * x_hat)) * inv_std * w
# per element, using the pre-reduced per-feature sums.
@triton.jit
def _bn_backward_input_kernel(
    output_grad_ptr,
    input_ptr,
    mean_ptr,
    inv_std_ptr,
    weight_ptr,
    sum_dy_ptr,
    sum_dy_xhat_ptr,
    input_grad_ptr,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    BLOCK: tl.constexpr,
):
    feat = tl.program_id(0)
    block_id = tl.program_id(1)
    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    total = batch_dim * spatial_dim
    mask = offset < total
    # Decompose the flat per-feature offset into (batch, spatial) coords.
    batch_idx = offset // spatial_dim
    spatial_idx = offset - batch_idx * spatial_dim
    mean = tl.load(mean_ptr + feat).to(tl.float32)
    inv_std = tl.load(inv_std_ptr + feat).to(tl.float32)
    sum_dy = tl.load(sum_dy_ptr + feat)
    sum_dy_xhat = tl.load(sum_dy_xhat_ptr + feat)
    count = tl.full((), total, tl.float32)
    # Weight defaults to 1 when the pointer is null (no affine transform).
    weight = tl.load(weight_ptr + feat).to(tl.float32) if weight_ptr else 1.0
    grad_ptrs = (
        output_grad_ptr
        + feat * output_grad_feat_stride
        + batch_idx * output_grad_batch_stride
        + spatial_idx * output_grad_spatial_stride
    )
    input_ptrs = (
        input_ptr
        + feat * input_feat_stride
        + batch_idx * input_batch_stride
        + spatial_idx * input_spatial_stride
    )
    input_grad_ptrs = (
        input_grad_ptr
        + feat * input_grad_feat_stride
        + batch_idx * input_grad_batch_stride
        + spatial_idx * input_grad_spatial_stride
    )
    dy = tl.load(grad_ptrs, mask=mask, other=0.0).to(tl.float32)
    x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_hat = (x - mean) * inv_std
    term = (dy - sum_dy / count - x_hat * sum_dy_xhat / count) * inv_std * weight
    tl.store(input_grad_ptrs, term, mask=mask)
# Parameter-gradient kernel: grad_weight is sum(dy * x_hat) and grad_bias
# is sum(dy), both already reduced; this just copies the requested ones
# into the output buffers (the constexpr masks compile the stores away
# when a gradient is not needed).
@triton.jit
def _bn_backward_param_kernel(
    sum_dy_ptr,
    sum_dy_xhat_ptr,
    weight_grad_ptr,
    bias_grad_ptr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
):
    feat = tl.program_id(0)
    if weight_grad_mask:
        tl.store(weight_grad_ptr + feat, tl.load(sum_dy_xhat_ptr + feat))
    if bias_grad_mask:
        tl.store(bias_grad_ptr + feat, tl.load(sum_dy_ptr + feat))
def _get_launch_config(batch_dim: int, spatial_dim: int) -> Tuple[int, int, int]:
    """Return (block_size, num_blocks, num_warps) for a per-feature reduction
    over batch_dim * spatial_dim elements."""
    total = batch_dim * spatial_dim
    block = _block_size(total)
    # Ceiling division, equivalent to triton.cdiv(total, block).
    num_blocks = (total + block - 1) // block
    return block, num_blocks, _num_warps(block)
def batch_norm(
    input: torch.Tensor,
    weight=None,
    bias=None,
    running_mean=None,
    running_var=None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-05,
):
    """Batch-norm forward; returns (output, save_mean, save_inv_std).

    Dispatch, by per-feature element count (batch * spatial):
      1. <= _NATIVE_SWITCH_ELEMENTS: defer to the native aten kernel
         (placeholder running stats are substituted if the caller passed
         none, since the aten overload requires them).
      2. <= _FALLBACK_ELEMENTS and stats must be computed: single fused
         Triton kernel per feature.
      3. Otherwise: two-stage Triton reduction for the stats (or a copy
         from running stats in eval mode), then an elementwise apply kernel.
    """
    logger.debug("GEMS_MTHREADS BATCHNORM FORWARD")
    input_3d = _make_3d_for_bn(input)
    batch_dim, feat_dim, spatial_dim = input_3d.shape
    total = batch_dim * spatial_dim
    # Path 1: small problems go to the native implementation.
    if total <= _NATIVE_SWITCH_ELEMENTS:
        rm = running_mean
        rv = running_var
        if rm is None or rv is None:
            # Cached zero/one placeholders; the aten op needs real tensors.
            rm, rv = _get_temp_stats(input.device, input.dtype, feat_dim)
        with torch_device_fn.device(input.device):
            return torch.ops.aten._native_batch_norm_legit.default(
                input, weight, bias, rm, rv, training, momentum, eps
            )
    output = torch.empty_like(input_3d)
    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty_like(mean)
    # Stats must be computed unless we are in eval mode with running stats.
    need_stats = training or running_mean is None or running_var is None
    # Running stats are only updated when training with stats provided.
    update_running = training and running_mean is not None and running_var is not None
    small_training = total <= _FALLBACK_ELEMENTS and (
        training or running_mean is None or running_var is None
    )
    # Path 2: one fused kernel computes stats and normalizes in one launch.
    if small_training:
        block = _block_size(total)
        num_warps = _num_warps(block)
        with torch_device_fn.device(input.device):
            _bn_fused_train_kernel[(feat_dim,)](
                input_3d,
                weight,
                bias,
                output,
                mean,
                inv_std,
                # mean/inv_std serve as dummy pointers when there are no
                # running stats; update_running is False in that case.
                running_mean if running_mean is not None else mean,
                running_var if running_var is not None else inv_std,
                batch_dim,
                spatial_dim,
                *input_3d.stride(),
                *output.stride(),
                momentum,
                eps,
                update_running=update_running,
                BLOCK=block,
                num_warps=num_warps,
            )
        return output.view_as(input), mean, inv_std
    # Path 3: staged reduction + separate apply kernel.
    block, num_blocks, num_warps = _get_launch_config(batch_dim, spatial_dim)
    with torch_device_fn.device(input.device):
        if need_stats:
            partial_shape = (feat_dim, num_blocks)
            partial_sum = torch.empty(
                partial_shape, device=input.device, dtype=torch.float32
            )
            partial_sq = torch.empty_like(partial_sum)
            _bn_forward_stats_stage1[(feat_dim, num_blocks)](
                input_3d,
                partial_sum,
                partial_sq,
                batch_dim,
                spatial_dim,
                *input_3d.stride(),
                num_blocks,
                BLOCK=block,
                num_warps=num_warps,
            )
            if num_blocks == 1:
                # Single block per feature: the partials are the totals.
                sum_buf = partial_sum[:, 0].contiguous()
                sum_sq_buf = partial_sq[:, 0].contiguous()
            else:
                # Zero-initialized because the reduce kernel accumulates
                # with atomic adds.
                sum_buf = torch.zeros(
                    (feat_dim,), device=input.device, dtype=torch.float32
                )
                sum_sq_buf = torch.zeros_like(sum_buf)
                reduce_blocks = triton.cdiv(num_blocks, _REDUCE_BLOCK)
                _bn_reduce_partial_kernel[(feat_dim, reduce_blocks)](
                    partial_sum,
                    partial_sq,
                    sum_buf,
                    sum_sq_buf,
                    num_blocks,
                    BLOCK=_REDUCE_BLOCK,
                    num_warps=_num_warps(_REDUCE_BLOCK),
                )
            _bn_forward_finalize_kernel[(feat_dim,)](
                sum_buf,
                sum_sq_buf,
                mean,
                inv_std,
                running_mean,
                running_var,
                total,
                momentum,
                eps,
                update_running=update_running,
                num_warps=1,
            )
        else:
            if running_mean is None or running_var is None:
                raise RuntimeError(
                    "running_mean and running_var are required in eval mode"
                )
            # Eval mode: use the caller-provided statistics directly.
            mean.copy_(running_mean)
            inv_std.copy_((running_var + eps).rsqrt())
        _bn_forward_apply_kernel[(feat_dim, num_blocks)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            BLOCK=block,
            num_warps=num_warps,
        )
    return output.view_as(input), mean, inv_std
def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train: bool = False,
    eps: float = 1e-05,
    output_mask=None,
):
    """Batch-norm backward; returns (grad_input, grad_weight, grad_bias).

    Entries of the result are None where the corresponding output_mask
    flag is False. Gradients are computed from the saved forward stats
    (save_mean, save_invstd) via a staged Triton reduction of sum(dy) and
    sum(dy * x_hat).

    NOTE(review): running_mean/running_var/train/eps are accepted for
    signature compatibility but unused here — the saved stats are always
    used; confirm eval-mode callers pass running stats through
    save_mean/save_invstd.
    """
    logger.debug("GEMS_MTHREADS BATCHNORM BACKWARD")
    # Fix: previously the default output_mask=None crashed with a
    # TypeError at output_mask[0]; default to computing all gradients.
    if output_mask is None:
        output_mask = (True, True, True)
    input_3d = _make_3d_for_bn(input)
    output_grad_3d = _make_3d_for_bn(grad_out)
    batch_dim, feat_dim, spatial_dim = input_3d.shape
    # Allocate only the gradients the caller asked for.
    if output_mask[0]:
        input_grad = torch.empty_like(input_3d)
    else:
        input_grad = None
    if output_mask[1]:
        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        weight_grad = None
    if output_mask[2]:
        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        bias_grad = None
    block, num_blocks, num_warps = _get_launch_config(batch_dim, spatial_dim)
    with torch_device_fn.device(input.device):
        # Stage 1: per-(feature, block) partial sums of dy and dy * x_hat.
        partial_shape = (feat_dim, num_blocks)
        partial_sum = torch.empty(
            partial_shape, device=input.device, dtype=torch.float32
        )
        partial_sum_xhat = torch.empty_like(partial_sum)
        _bn_backward_reduce_kernel[(feat_dim, num_blocks)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            partial_sum,
            partial_sum_xhat,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            num_blocks,
            BLOCK=block,
            num_warps=num_warps,
        )
        # Stage 2: collapse the partials into per-feature totals.
        if num_blocks == 1:
            sum_dy = partial_sum[:, 0].contiguous()
            sum_dy_xhat = partial_sum_xhat[:, 0].contiguous()
        else:
            # Zero-initialized: the reduce kernel accumulates atomically.
            sum_dy = torch.zeros((feat_dim,), device=input.device, dtype=torch.float32)
            sum_dy_xhat = torch.zeros_like(sum_dy)
            reduce_blocks = triton.cdiv(num_blocks, _REDUCE_BLOCK)
            _bn_backward_reduce_partial_kernel[(feat_dim, reduce_blocks)](
                partial_sum,
                partial_sum_xhat,
                sum_dy,
                sum_dy_xhat,
                num_blocks,
                BLOCK=_REDUCE_BLOCK,
                num_warps=_num_warps(_REDUCE_BLOCK),
            )
        if output_mask[0]:
            _bn_backward_input_kernel[(feat_dim, num_blocks)](
                output_grad_3d,
                input_3d,
                save_mean,
                save_invstd,
                weight,
                sum_dy,
                sum_dy_xhat,
                input_grad,
                batch_dim,
                spatial_dim,
                *output_grad_3d.stride(),
                *input_3d.stride(),
                *input_grad.stride(),
                BLOCK=block,
                num_warps=num_warps,
            )
        if output_mask[1] or output_mask[2]:
            _bn_backward_param_kernel[(feat_dim,)](
                sum_dy,
                sum_dy_xhat,
                # Unrequested grads get a dummy buffer; the constexpr mask
                # prevents the kernel from writing to it.
                weight_grad if weight_grad is not None else sum_dy,
                bias_grad if bias_grad is not None else sum_dy,
                weight_grad_mask=output_mask[1],
                bias_grad_mask=output_mask[2],
                num_warps=1,
            )
    return (
        input_grad.view_as(input) if input_grad is not None else None,
        weight_grad,
        bias_grad,
    )