Coverage for src/flag_gems/ops/batch_norm.py: 35% (154 statements)

import logging

import torch
import triton
import triton.language as tl
from torch import Tensor

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, tl_extra_shim

logger = logging.getLogger(__name__)
rsqrt = tl_extra_shim.rsqrt


def make_3d_for_bn(input: Tensor) -> Tensor:
    """
    Converts the input to a 3D view for batch normalization.

    Args:
        input: Input tensor to view as 3D.

    Returns:
        A 3D view of the input with shape (batch, features, spatial).
    """
    if input.ndim == 2:
        input = input.unsqueeze(-1)

    elif input.ndim >= 4:
        input = input.flatten(2, -1)

    return input
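

# For example (shapes only, assuming the usual channels-second layout):
#   (N, C)       -> (N, C, 1)
#   (N, C, L)    -> (N, C, L)    # unchanged
#   (N, C, H, W) -> (N, C, H*W)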


# NOTE: This part of the kernel code is copied and modified from the
# attorch codebase: https://github.com/BobMcDear/attorch
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("batch_norm"),
    key=["batch_dim", "spatial_dim"],
    restore_value=["running_mean_pointer", "running_var_pointer"],
)
@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_forward_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    mean_pointer,
    inv_std_pointer,
    output_pointer,
    running_mean_pointer,
    running_var_pointer,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    feat_pid = tl.program_id(axis=0)
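    # Each program instance normalizes one feature (channel), selected by
    # feat_pid along grid axis 0.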

    # Training mode tracks running statistics by default.
    if is_train:
        mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        var = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        cnt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

        m_num_steps = tl.cdiv(batch_dim, BLOCK_M)
        n_num_steps = tl.cdiv(spatial_dim, BLOCK_N)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)
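
                # Welford-style streaming update per lane, using the tile
                # counter `step` as the running count:
                #   mean <- mean + (x - mean) / step
                #   M2   <- M2 + (x - mean_new) * (x - mean_old)
                # Note that `var` holds M2 (the sum of squared deviations),
                # not the variance itself; it is normalized after the loops.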
                step = m_step * n_num_steps + n_step + 1
                new_mean = tl.where(mask, mean + (curr_input - mean) / step, mean)
                new_var = tl.where(
                    mask, var + (curr_input - new_mean) * (curr_input - mean), var
                )
                cnt += mask.to(tl.int32)
                mean = new_mean
                var = new_var
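
        # Combine the per-lane partial statistics (Chan et al.'s parallel
        # variance algorithm): the total M2 is sum(M2_i) plus the correction
        # sum(n_i * (mean_i - mean_total)^2); dividing by the total element
        # count yields the biased variance.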
        final_mean = tl.sum(mean * cnt) / (batch_dim * spatial_dim)
        var = tl.sum(var + cnt * (mean - final_mean) * (mean - final_mean)) / (
            batch_dim * spatial_dim
        )
        inv_std = rsqrt(var + eps)
        mean = final_mean

        tl.store(feat_pid + mean_pointer, mean)
        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid

        running_mean = tl.load(running_mean_pointer)
        running_var = tl.load(running_var_pointer)

        n = batch_dim * spatial_dim
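        # The running variance is updated with the unbiased estimator, hence
        # the n / (n - 1) Bessel correction below.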
        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
        tl.store(
            running_var_pointer,
            (1 - momentum) * running_var + momentum * var * n / (n - 1),
        )

    else:
        mean = tl.load(feat_pid + running_mean_pointer)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    if bias_pointer:
        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
    else:
        bias = 0.0
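
    # Second pass: normalize and apply the affine transform,
    #   y = weight * (x - mean) * inv_std + bias.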
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_output_pointer = (
                output_pointer
                + output_feat_stride * feat_pid
                + output_batch_stride * batch_offset[:, None]
                + output_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            output = weight * (curr_input - mean) * inv_std + bias

            tl.store(
                curr_output_pointer,
                output,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("batch_norm"),
    key=["batch_dim", "spatial_dim"],
)
@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_backward_kernel(
    output_grad_pointer,
    input_pointer,
    mean_pointer,
    inv_std_pointer,
    weight_pointer,
    input_grad_pointer,
    weight_grad_pointer,
    bias_grad_pointer,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    feat_pid = tl.program_id(axis=0)

    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)
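
    # First pass accumulates the two reductions every gradient needs:
    #   term1 = sum(dy * x_hat)  -> the weight gradient
    #   term2 = sum(dy)          -> the bias gradient
    # where x_hat = (x - mean) * inv_std.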
    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )

            mask = batch_mask[:, None] & spatial_mask[None, :]
            curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)

            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(curr_output_grad_pointer, mask=mask).to(
                tl.float32
            )

            term1 += curr_pre_lin * curr_output_grad
            term2 += curr_output_grad

    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

    if weight_grad_mask:
        tl.store(feat_pid + weight_grad_pointer, term1)
    if bias_grad_mask:
        tl.store(feat_pid + bias_grad_pointer, term2)

    if not input_grad_mask:
        return

    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0

    count = batch_dim * spatial_dim

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_input_grad_pointer = (
                input_grad_pointer
                + input_grad_feat_stride * feat_pid
                + input_grad_batch_stride * batch_offset[:, None]
                + input_grad_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            ).to(tl.float32)
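            # Batch-norm input gradient computed from the two reductions:
            #   dx = inv_std * weight * (dy - (term1 * x_hat + term2) / N)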
            curr_input_grad = (
                inv_std
                * weight
                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
            )
            tl.store(
                curr_input_grad_pointer,
                curr_input_grad,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )


def batch_norm(
    input: Tensor,
    weight=None,
    bias=None,
    running_mean=None,  # self.running_mean if not self.training or self.track_running_stats else None
    running_var=None,
    training=False,  # (self.running_mean is None) and (self.running_var is None)
    momentum=0.1,
    eps=1e-05,
):
    logger.debug("GEMS BATCHNORM FORWARD")

    input_3d = make_3d_for_bn(input)

    batch_dim, feat_dim, spatial_dim = input_3d.shape
    output = torch.empty_like(input_3d)

    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
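
    # When running stats are not tracked, pass the input tensor as a
    # placeholder so the kernel still receives valid pointers.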
    running_mean = input if running_mean is None else running_mean
    running_var = input if running_var is None else running_var

    # Launches 1D grid where each program operates over one feature.
    with torch_device_fn.device(input.device):
        batch_norm_forward_kernel[(feat_dim,)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            running_mean,
            running_var,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            momentum,
            eps,
            is_train=training,
        )

    return output.view_as(input), mean, inv_std
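

# A minimal usage sketch (hypothetical shapes and names, assuming a device
# supported by flag_gems; not part of the original module):
#   x = torch.randn(8, 16, 32, 32, device="cuda")
#   w = torch.ones(16, device="cuda")
#   b = torch.zeros(16, device="cuda")
#   rm = torch.zeros(16, device="cuda")
#   rv = torch.ones(16, device="cuda")
#   out, mean, inv_std = batch_norm(x, w, b, rm, rv, training=True)
#   # out has x's shape; mean/inv_std hold the per-channel batch statistics.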


def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-05,
    output_mask=None,
):
    logger.debug("GEMS BATCHNORM BACKWARD")
    input_3d = make_3d_for_bn(input)
    output_grad_3d = make_3d_for_bn(grad_out)

    batch_dim, feat_dim, spatial_dim = input_3d.shape

    if output_mask[0]:
        input_grad = torch.empty_like(input_3d)
    else:
        input_grad = None
    if output_mask[1]:
        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        weight_grad = None
    if output_mask[2]:
        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        bias_grad = None

    # If the input gradient is not requested, input_grad is None; fall back to
    # the input's strides so the (unused) stride arguments remain valid.
    input_grad_stride = (
        input_grad.stride() if input_grad is not None else input_3d.stride()
    )

    # Launches 1D grid where each program operates over one feature.
    with torch_device_fn.device(input.device):
        batch_norm_backward_kernel[(feat_dim,)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            weight,
            input_grad,
            weight_grad,
            bias_grad,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            *input_grad_stride,
            *output_mask,
        )

    # Gradients disabled via output_mask are returned as None.
    return (
        input_grad.view_as(input) if input_grad is not None else None,
        weight_grad,
        bias_grad,
    )
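

# A minimal backward usage sketch (hypothetical names, continuing the forward
# example above; not part of the original module):
#   grad_out = torch.ones_like(out)
#   grad_in, grad_w, grad_b = batch_norm_backward(
#       grad_out, x, w,
#       save_mean=mean, save_invstd=inv_std,
#       train=True, output_mask=[True, True, True],
#   )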