Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/batch_norm.py: 0%
168 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, tl_extra_shim
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
13rsqrt = tl_extra_shim.rsqrt
def make_3d_for_bn(input: Tensor) -> Tensor:
    """
    Return a 3D (batch, feature, spatial) view of *input* for batch norm.

    A 2D input gains a trailing singleton spatial dimension; an input with
    four or more dimensions has all trailing spatial dimensions flattened
    into one. A 3D input is returned unchanged.

    Args:
        input: Input to render 3D.

    Returns:
        Input's 3D view.
    """
    ndim = input.ndim
    if ndim == 2:
        return input.unsqueeze(-1)
    if ndim >= 4:
        return input.flatten(2, -1)
    return input
# NOTE: This part of the kernel code is copied and modified
# from the https://github.com/BobMcDear/attorch codebase.


@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("batch_norm"),
#     key=["batch_dim", "spatial_dim"],
#     restore_value=["running_mean_pointer", "running_var_pointer"],
# )
@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_forward_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    mean_pointer,
    inv_std_pointer,
    output_pointer,
    running_mean_pointer,
    running_var_pointer,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Batch-norm forward over a logically 3D (batch, feature, spatial) input.
    One program instance handles exactly one feature channel.

    Training mode computes the batch mean/inverse-std on the fly, stores
    them to ``mean_pointer``/``inv_std_pointer`` (for the backward pass),
    and folds them into the running statistics. Eval mode reads the running
    statistics instead. ``weight_pointer``/``bias_pointer`` may be null, in
    which case the affine transform degenerates to scale 1 / shift 0.
    """
    # Grid axis 0 enumerates the feature (channel) dimension.
    feat_pid = tl.program_id(axis=0)

    # training mode default track_running_stat
    if is_train:
        # Per-lane incremental accumulators for mean/variance; cnt tracks
        # how many valid (unmasked) samples each lane absorbed so the
        # partial results can be merged exactly after the loop.
        mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        var = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        cnt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

        m_num_steps = tl.cdiv(batch_dim, BLOCK_M)
        n_num_steps = tl.cdiv(spatial_dim, BLOCK_N)
        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)

                # Incremental (Welford-style) update per lane; masked lanes
                # keep their previous values. ``step`` counts processed
                # tiles, not per-lane samples.
                # NOTE(review): dividing by the tile step assumes each lane
                # has seen exactly ``step`` valid samples, which only holds
                # on full tiles — edge tiles rely on the weighted merge
                # below to compensate; confirm against reference outputs.
                step = m_step * n_num_steps + n_step + 1
                new_mean = tl.where(mask, mean + (curr_input - mean) / step, mean)
                new_var = tl.where(
                    mask, var + (curr_input - new_mean) * (curr_input - mean), var
                )
                cnt += mask.to(tl.int32)
                mean = new_mean
                var = new_var

        # Merge per-lane partials into scalars: count-weighted mean, then
        # the parallel-variance combination var_i + cnt_i * (mean_i - M)^2.
        final_mean = tl.sum(mean * cnt) / (batch_dim * spatial_dim)
        var = tl.sum(var + cnt * (mean - final_mean) * (mean - final_mean)) / (
            batch_dim * spatial_dim
        )
        inv_std = rsqrt(var + eps)
        mean = final_mean

        # Save the batch statistics for the backward pass.
        tl.store(feat_pid + mean_pointer, mean)
        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid

        running_mean = tl.load(running_mean_pointer)
        running_var = tl.load(running_var_pointer)

        # Exponential moving average of the running stats; the variance
        # term is rescaled by n / (n - 1) (Bessel's correction).
        n = batch_dim * spatial_dim
        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
        tl.store(
            running_var_pointer,
            (1 - momentum) * running_var + momentum * var * n / (n - 1),
        )

    else:
        # Eval mode: normalize with the tracked running statistics.
        mean = tl.load(feat_pid + running_mean_pointer)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

    # Null pointers mean "no affine parameters": identity scale/shift.
    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    if bias_pointer:
        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
    else:
        bias = 0.0

    # Second pass: y = weight * (x - mean) * inv_std + bias, tile by tile.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_output_pointer = (
                output_pointer
                + output_feat_stride * feat_pid
                + output_batch_stride * batch_offset[:, None]
                + output_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            output = weight * (curr_input - mean) * inv_std + bias

            tl.store(
                curr_output_pointer,
                output,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )
def batch_norm_heur_block_m(args):
    """Choose BLOCK_M: the batch-tile size, a power of two capped at 64."""
    batch_pow2 = triton.next_power_of_2(args["batch_dim"])
    return min(64, batch_pow2)
def batch_norm_heur_block_n(args):
    """Choose BLOCK_N so one tile loads at most 16384 (2**14) elements."""
    block_m = batch_norm_heur_block_m(args)
    spatial_pow2 = triton.next_power_of_2(args["spatial_dim"])
    # Budget the spatial tile by what BLOCK_M leaves of the element limit.
    budget = max(1, 2**14 // block_m)
    return min(spatial_pow2, budget)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("batch_norm"),
#     key=["batch_dim", "spatial_dim"],
# )
@triton.heuristics(
    values={
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
)
# @triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_backward_kernel(
    output_grad_pointer,
    input_pointer,
    mean_pointer,
    inv_std_pointer,
    weight_pointer,
    input_grad_pointer,
    weight_grad_pointer,
    bias_grad_pointer,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Batch-norm backward; one program instance per feature channel.

    Pass 1 accumulates term1 = sum(x_hat * dy) (stored as the weight grad)
    and term2 = sum(dy) (stored as the bias grad). Pass 2 uses both to form
    the input gradient. The *_grad_mask constexpr flags select which
    gradients are actually computed/stored; ``weight_pointer`` may be null
    (no affine weight, treated as 1).
    """
    feat_pid = tl.program_id(axis=0)

    # Batch statistics saved by the forward pass.
    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)

    # Tile-shaped accumulators; reduced to scalars after the first pass.
    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    # Pass 1: accumulate the two reduction terms tile by tile.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
        batch_mask = batch_offset < batch_dim

        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )

            mask = batch_mask[:, None] & spatial_mask[None, :]
            curr_input = tl.load(curr_input_pointer, mask=mask, other=0).to(tl.float32)

            # x_hat = (x - mean) * inv_std (the normalized input).
            curr_pre_lin = ((curr_input - mean) * inv_std).to(tl.float32)
            curr_output_grad = tl.load(
                curr_output_grad_pointer, mask=mask, other=0.0
            ).to(tl.float32)

            term1 += curr_pre_lin * curr_output_grad
            term2 += curr_output_grad

    # Collapse the tile accumulators to scalars.
    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

    if weight_grad_mask:
        tl.store(feat_pid + weight_grad_pointer, term1)
    if bias_grad_mask:
        tl.store(feat_pid + bias_grad_pointer, term2)

    # Nothing more to do unless the input gradient was requested.
    if not input_grad_mask:
        return

    # Null weight pointer means "no affine weight": treat as 1.
    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    # NOTE(review): calling .to() on the 1.0 fallback relies on Triton
    # promoting the Python float to a tl scalar inside @jit — confirm this
    # compiles on the target backend.
    weight = weight.to(tl.float32)

    count = batch_dim * spatial_dim

    # Pass 2: dx = inv_std * w * (dy - (term1 * x_hat + term2) / N).
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_input_grad_pointer = (
                input_grad_pointer
                + input_grad_feat_stride * feat_pid
                + input_grad_batch_stride * batch_offset[:, None]
                + input_grad_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            ).to(tl.float32)
            curr_input_grad = (
                inv_std
                * weight
                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
            )
            tl.store(
                curr_input_grad_pointer,
                curr_input_grad,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )
def batch_norm(
    input: Tensor,
    weight=None,
    bias=None,
    running_mean=None,  # self.running_mean if not self.training or self.track_running_state else None
    running_var=None,
    training=False,  # (self.running_mean is None) and (self.running_var is None)
    momentum=0.1,
    eps=1e-05,
):
    """
    Batch normalization forward entry point.

    Reshapes the input so each kernel program walks one feature channel,
    launches the Triton forward kernel, and restores the caller's layout.

    Returns:
        (normalized output shaped like ``input``, batch mean, batch inv_std)
    """
    logger.debug("GEMS BATCHNORM FORWARD")

    # Collapse to (batch, feat, spatial), then re-layout the data as a
    # (batch * spatial, feat, 1) view for the kernel launch.
    view_3d = make_3d_for_bn(input)
    m, n, k = view_3d.shape
    flat = view_3d.permute(0, 2, 1).reshape(-1, n)
    input_3d = make_3d_for_bn(flat)

    batch_dim, feat_dim, spatial_dim = input_3d.shape
    output = torch.empty_like(input_3d)

    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)

    # The kernel dereferences the running-stat pointers unconditionally, so
    # substitute the input tensor as a placeholder when they are absent.
    # NOTE(review): in training mode the kernel also *stores* into these
    # pointers, i.e. into ``input`` when no buffers were given — confirm
    # this is the intended placeholder behavior.
    if running_mean is None:
        running_mean = input
    if running_var is None:
        running_var = input

    # Launches 1D grid where each program operates over one feature.
    with torch_device_fn.device(input.device):
        batch_norm_forward_kernel[(feat_dim,)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            running_mean,
            running_var,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            momentum,
            eps,
            is_train=training,
            buffer_size_limit=2048,
        )

    # Undo the permute/reshape so the result matches the caller's layout.
    restored = output.reshape(m, k, n).permute(0, 2, 1)
    return restored.view_as(input), mean, inv_std
def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-05,
    output_mask=None,
):
    """
    Batch normalization backward entry point.

    Reshapes grad_out/input the same way as the forward pass, launches the
    Triton backward kernel, and returns the gradients for (input, weight,
    bias). ``output_mask`` is a 3-tuple of booleans selecting which of
    those gradients to compute; unselected slots are returned as None.
    """
    logger.debug("GEMS BATCHNORM BACKWARD")

    # Robustness: the original crashed on subscripting when no mask was
    # supplied; default to computing all three gradients.
    if output_mask is None:
        output_mask = (True, True, True)

    input_3d_i = make_3d_for_bn(input)
    m, n, k = input_3d_i.shape
    input_3d_f = input_3d_i.permute(0, 2, 1).reshape(-1, n)
    input_3d = make_3d_for_bn(input_3d_f)

    output_grad_3d_i = make_3d_for_bn(grad_out)
    output_grad_3d_f = output_grad_3d_i.permute(0, 2, 1).reshape(-1, n)
    output_grad_3d = make_3d_for_bn(output_grad_3d_f)

    batch_dim, feat_dim, spatial_dim = input_3d.shape

    if output_mask[0]:
        input_grad = torch.empty_like(input_3d)
    else:
        input_grad = None
    if output_mask[1]:
        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        weight_grad = None
    if output_mask[2]:
        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        bias_grad = None

    # BUGFIX: the original unconditionally unpacked input_grad.stride(),
    # which raised AttributeError whenever output_mask[0] was False
    # (input_grad is None). The kernel only reads these strides when
    # input_grad_mask is set, so zeros are a safe placeholder.
    if input_grad is not None:
        input_grad_strides = input_grad.stride()
    else:
        input_grad_strides = (0, 0, 0)

    # Launches 1D grid where each program operates over one feature.
    with torch_device_fn.device(input.device):
        batch_norm_backward_kernel[(feat_dim, 1, 1)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            weight,
            input_grad,
            weight_grad,
            bias_grad,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            *input_grad_strides,
            *output_mask,
            buffer_size_limit=2048,
        )

    # BUGFIX: guard the final reshape too — the original dereferenced
    # input_grad even when it was None.
    if input_grad is not None:
        input_grad = input_grad.reshape(m, k, n).permute(0, 2, 1).view_as(input)

    # Pads output with None because a gradient is necessary for
    # all input arguments.
    return (
        input_grad,
        weight_grad,
        bias_grad,
    )