Coverage for src/flag_gems/runtime/backend/_cambricon/ops/groupnorm.py: 0%
343 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, tl_extra_shim
10from ..utils import TOTAL_CORE_NUM
# Child logger under the shared "flag_gems" root logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Device-specific rsqrt from the runtime shim.
# NOTE(review): not referenced by the kernels in this file (they call
# tl.rsqrt directly) -- possibly kept for parity with other backends; confirm.
rsqrt = tl_extra_shim.rsqrt
def group_norm_kernel_opt_prune(configs, named_args, **kwargs):
    """Early-prune autotune configs for the forward GroupNorm kernel.

    Keeps a config when either:
      * HW is large (> 4096) and the config uses a big tile (>= 4096) with
        SPLIT <= 1, or
      * the tile covers the whole HW axis (BLOCK_HW_SIZE >= HW), SPLIT does
        not exceed the group count, and no other candidate tile size lies
        strictly between HW and this one (i.e. keep only the tightest
        covering tile sizes).
    """
    hw = kwargs["HW"]
    num_groups = named_args["num_groups"]

    # Distinct candidate tile sizes, in first-seen order.
    candidate_sizes = []
    for cfg in configs:
        size = cfg.kwargs["BLOCK_HW_SIZE"]
        if size not in candidate_sizes:
            candidate_sizes.append(size)

    kept = []
    for cfg in configs:
        size = cfg.kwargs["BLOCK_HW_SIZE"]
        split = cfg.kwargs["SPLIT"]
        if hw > 4096 and size >= 4096 and split <= 1:
            kept.append(cfg)
        elif size >= hw and split <= num_groups:
            # Drop this config if a smaller candidate already covers HW.
            if not any(hw < other < size for other in candidate_sizes):
                kept.append(cfg)
    return kept
# GroupNorm forward kernel: each program normalizes up to SPLIT consecutive
# groups of one batch sample and writes Y plus the per-group Mean/Rstd
# statistics consumed by the backward pass.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"SPLIT": s, "BLOCK_HW_SIZE": size}, num_stages=3, num_warps=1)
        for size in [64, 256, 512, 1024, 2048, 4096, 5120]
        for s in [1, 4, 6, 8, 16]
    ],
    key=["X", "group_size", "C", "HW", "num_groups"],
    prune_configs_by={"early_config_prune": group_norm_kernel_opt_prune},
)
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel_opt(
    X,  # input, contiguous, logically (N, C, HW)
    Y,  # output, same layout as X
    W,  # optional per-channel weight (C,); may be None
    B,  # optional per-channel bias (C,); may be None
    Mean,  # (N, num_groups) per-group mean output
    Rstd,  # (N, num_groups) per-group reciprocal-stddev output
    group_size,  # channels per group
    C,  # total channel count
    num_groups,
    eps,  # epsilon added to the variance for numerical stability
    HW: tl.constexpr,  # spatial elements per channel
    BLOCK_GROUP_SIZE: tl.constexpr,  # tile height over a group's channels
    BLOCK_HW_SIZE: tl.constexpr,  # tile width over the HW axis
    SPLIT: tl.constexpr,  # groups handled per program
):
    # Grid is (N * ceil(num_groups / SPLIT),): split pid into the sample index
    # and the index of this program's chunk of SPLIT groups.
    pid = tl.program_id(0)
    div_v = tl.cdiv(num_groups, SPLIT)
    div_mod = num_groups % SPLIT
    split_group = pid % div_v
    split_n = pid // div_v
    real_num_elements = group_size * HW  # elements reduced per group

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)
    if BLOCK_HW_SIZE >= HW:
        # Whole HW axis fits in one tile: use an exact-width range so the
        # fast path below needs no masking.
        hw_offset = tl.arange(0, HW)
    hw_iter = tl.cdiv(HW, BLOCK_HW_SIZE)

    # Base pointers for the first group of this chunk.
    if W is None:
        W_ptr = None
    else:
        W_ptr = W + split_group * SPLIT * group_size
    if B is None:
        B_ptr = None
    else:
        B_ptr = B + split_group * SPLIT * group_size

    Mean_ptr = Mean + split_n * num_groups + split_group * SPLIT
    Rstd_ptr = Rstd + split_n * num_groups + split_group * SPLIT

    xy_offset = (
        split_n * C * HW
        + split_group * SPLIT * real_num_elements
        + group_offset[:, None] * HW
        + hw_offset[None, :]
    )

    # The last chunk may hold fewer than SPLIT groups.
    ub = SPLIT
    if (div_mod != 0) and ((split_group + 1) == div_v):
        ub = div_mod
    for idx in range(0, ub):
        if BLOCK_HW_SIZE >= HW:
            # Fast path: one unmasked load covers the whole group.
            tmp = tl.load(X + xy_offset, cache_modifier=".cg").to(tl.float32)
            mean = tl.sum(tmp) / real_num_elements
            x = tmp - mean
            var = tl.sum(x * x) / real_num_elements
            var = tl.rsqrt(var + eps)  # 'var' now holds rstd, not variance

            tl.store(Mean_ptr + idx, mean)
            tl.store(Rstd_ptr + idx, var)

            if W_ptr is None:
                weight = 1
            else:
                weight = tl.load(W_ptr + group_offset, cache_modifier=".cg")[:, None]
            if B_ptr is None:
                bias = 0
            else:
                bias = tl.load(B_ptr + group_offset, cache_modifier=".cg")[:, None]
            tmp = (tmp - mean) * var
            tmp = tmp * weight + bias
            tl.store(Y + xy_offset, tmp)
        else:
            # Tiled path, pass 1: accumulate sum and sum-of-squares so that
            # mean/var follow from E[x^2] - E[x]^2.
            # NOTE(review): for idx > 0 the mask 'group_offset < group_size'
            # would be all-false (group_offset is advanced by group_size per
            # iteration), but the pruner only admits tiled configs with
            # SPLIT <= 1, so this path runs a single iteration -- confirm.
            mean = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], tl.float32)
            var = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], tl.float32)
            for idy in range(0, hw_iter):
                # NOTE(review): mask combined with Python 'and' here while the
                # backward kernel uses '&'; presumably equivalent under this
                # Triton version -- confirm.
                xy_mask = (
                    group_offset[:, None] < group_size
                    and (idy * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                tmp = tl.load(
                    X + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)
                mean += tmp
                var += tmp * tmp
            mean = tl.sum(mean) / real_num_elements
            var = tl.sum(var) / real_num_elements - (mean * mean)
            var = tl.rsqrt(var + eps)  # 'var' now holds rstd
            tl.store(Mean_ptr + idx, mean)
            tl.store(Rstd_ptr + idx, var)

            if W_ptr is None:
                weight = 1
            else:
                weight = tl.load(W_ptr + group_offset, cache_modifier=".cg")[:, None]
            if B_ptr is None:
                bias = 0
            else:
                bias = tl.load(B_ptr + group_offset, cache_modifier=".cg")[:, None]

            # Pass 2: reload each tile and write the normalized, affine output.
            for idy in range(0, hw_iter):
                xy_mask = (
                    group_offset[:, None] < group_size
                    and (idy * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                tmp = tl.load(
                    X + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)
                tmp = (tmp - mean) * var
                tmp = tmp * weight + bias
                tl.store(Y + idy * BLOCK_HW_SIZE + xy_offset, tmp, mask=xy_mask)

        # Advance to the next group of this chunk; group_offset also indexes
        # the next group's slice of W/B.
        xy_offset += real_num_elements
        group_offset += group_size
def group_norm_backward_kernel_opt_prune(configs, named_args, **kwargs):
    """Early-prune autotune configs for the backward (grad-input) kernel.

    Keeps a config when either:
      * HW is large (> 2048) and the config uses a big tile (>= 2048) with
        SPLIT <= 1, or
      * the tile covers the whole HW axis and no other candidate tile size
        lies strictly between HW and this one.

    Bug fix: the covering test was ``BLOCK_HW_SIZE > hw``; when HW equals the
    largest candidate size (2048) that pruned *every* config and left the
    autotuner with an empty list.  Use ``>=`` like the forward pruner.
    """
    pruned_configs = []
    hw = kwargs["HW"]
    all_sizes = []
    for config in configs:
        BLOCK_HW_SIZE = config.kwargs["BLOCK_HW_SIZE"]
        if BLOCK_HW_SIZE not in all_sizes:
            all_sizes.append(BLOCK_HW_SIZE)
    for config in configs:
        BLOCK_HW_SIZE = config.kwargs["BLOCK_HW_SIZE"]
        SPLIT = config.kwargs["SPLIT"]
        if (hw > 2048) and (BLOCK_HW_SIZE >= 2048) and (SPLIT <= 1):
            pruned_configs.append(config)
        elif BLOCK_HW_SIZE >= hw:  # was '>': empty result when hw == max size
            # Keep only the tightest tile sizes that still cover hw.
            if not any(hw < size < BLOCK_HW_SIZE for size in all_sizes):
                pruned_configs.append(config)
    return pruned_configs
# GroupNorm backward kernel for the input gradient: each program handles up
# to SPLIT consecutive groups of one sample, recomputing x_hat from the saved
# Mean/Rstd statistics.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"SPLIT": s, "BLOCK_HW_SIZE": size}, num_stages=3, num_warps=1)
        for s in [1, 4, 6, 8]
        for size in [64, 256, 512, 1024, 2048]
    ],
    prune_configs_by={"early_config_prune": group_norm_backward_kernel_opt_prune},
    key=["X", "group_size", "C", "HW", "num_groups"],
)
@triton.jit()
def group_norm_backward_kernel_opt(
    grad_y,  # upstream gradient, same contiguous layout as X
    X,  # forward input, logically (N, C, HW)
    W,  # optional per-channel weight (C,); may be None
    Mean,  # (N, num_groups) means saved by the forward pass
    Rstd,  # (N, num_groups) reciprocal stds saved by the forward pass
    num_groups,
    group_size,  # channels per group
    grad_x,  # output: gradient w.r.t. X
    C,  # total channel count
    HW: tl.constexpr,  # spatial elements per channel
    BLOCK_GROUP_SIZE: tl.constexpr,  # tile height over a group's channels
    BLOCK_HW_SIZE: tl.constexpr,  # tile width over the HW axis
    SPLIT: tl.constexpr,  # groups handled per program
):
    # Grid is (N * ceil(num_groups / SPLIT),).
    pid = tl.program_id(0)
    div_v = tl.cdiv(num_groups, SPLIT)
    div_mod = num_groups % SPLIT
    split_group = pid % div_v
    split_n = pid // div_v
    real_num_elements = group_size * HW  # elements per group
    hw_iter = tl.cdiv(HW, BLOCK_HW_SIZE)

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    if BLOCK_HW_SIZE >= HW:
        # Whole HW axis in one exact-width tile (unmasked fast path).
        hw_offset = tl.arange(0, HW)
    else:
        hw_offset = tl.arange(0, BLOCK_HW_SIZE)

    if W is None:
        W_ptr = None
    else:
        W_ptr = W + split_group * SPLIT * group_size

    Mean_ptr = Mean + split_n * num_groups + split_group * SPLIT
    Rstd_ptr = Rstd + split_n * num_groups + split_group * SPLIT

    xy_offset = (
        split_n * real_num_elements * num_groups
        + split_group * SPLIT * real_num_elements
        + group_offset[:, None] * HW
        + hw_offset[None, :]
    )

    # The last chunk may hold fewer than SPLIT groups.
    ub = SPLIT
    if (div_mod != 0) and ((split_group + 1) == div_v):
        ub = div_mod
    for idx in range(0, ub):
        # group_offset advances by group_size per iteration, so '< C' remains
        # a valid bound for every group of this chunk.
        wb_mask = group_offset < C

        if W_ptr is None:
            weight = 1
        else:
            weight = tl.load(
                W_ptr + group_offset, mask=wb_mask, other=0.0, cache_modifier=".cg"
            ).to(tl.float32)[:, None]
        rstd = tl.load(Rstd_ptr + idx).to(tl.float32)
        mean = tl.load(Mean_ptr + idx).to(tl.float32)

        if BLOCK_HW_SIZE >= HW:
            # Fast path: the whole group in one unmasked tile.
            dY_val = tl.load(grad_y + xy_offset, cache_modifier=".cg").to(tl.float32)
            X_val = tl.load(X + xy_offset, cache_modifier=".cg").to(tl.float32)

            x_hat = rstd * (X_val - mean)
            dx_hat = weight * dY_val

            grad_dx_hat_sum = tl.sum(dx_hat)
            grad_x_hat_sum = tl.sum(dx_hat * x_hat)

            # Standard normalization input-gradient:
            # dx = rstd * (dx_hat - (sum(dx_hat) + x_hat*sum(dx_hat*x_hat)) / m)
            dx = rstd * (
                dx_hat - (grad_dx_hat_sum + x_hat * grad_x_hat_sum) / real_num_elements
            )

            tl.store(grad_x + xy_offset, dx)
        else:
            # Tiled path, pass 1: accumulate the two group-wide reductions.
            grad_dx_hat_accum = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], tl.float32)
            grad_x_hat_accum = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], tl.float32)

            for idy in range(0, hw_iter):
                xy_mask = (group_offset[:, None] < C) & (
                    (idy * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                dY_val = tl.load(
                    grad_y + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)
                X_val = tl.load(
                    X + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)

                # Zero out-of-range lanes so they do not pollute the sums.
                x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
                dx_hat = weight * dY_val
                grad_dx_hat_accum += dx_hat
                grad_x_hat_accum += dx_hat * x_hat

            grad_dx_hat_total = tl.sum(grad_dx_hat_accum)
            grad_x_hat_total = tl.sum(grad_x_hat_accum)

            # Pass 2: reload each tile and emit the gradient.
            for idy in range(0, hw_iter):
                xy_mask = (group_offset[:, None] < C) & (
                    (idy * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                dY_val = tl.load(
                    grad_y + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)
                X_val = tl.load(
                    X + idy * BLOCK_HW_SIZE + xy_offset,
                    mask=xy_mask,
                    other=0.0,
                    cache_modifier=".cg",
                ).to(tl.float32)

                x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
                dx_hat = weight * dY_val
                dx = rstd * (
                    dx_hat
                    - (grad_dx_hat_total + x_hat * grad_x_hat_total) / real_num_elements
                )

                tl.store(grad_x + idy * BLOCK_HW_SIZE + xy_offset, dx, mask=xy_mask)

        # Advance to the next group of this chunk.
        xy_offset += real_num_elements
        group_offset += group_size
def weight_bias_backward_kernel_opt_prune(configs, named_args, **kwargs):
    """Early-prune autotune configs for the weight/bias backward kernel.

    Two stages: first select tile shapes the same way as the other pruners
    (large-HW fast path, or the tightest HW-covering tile sizes), then keep
    only configs whose BLOCK_N evenly divides the batch size N (the kernel's
    batch loop assumes exact division).

    Bug fix: the covering test was ``BLOCK_HW_SIZE > hw``; when HW equals the
    largest candidate size (2048) that pruned *every* config and left the
    autotuner with an empty list.  Use ``>=`` like the forward pruner.
    """
    pruned_configs = []
    pruned_configs_cached = []
    n = named_args["N"]
    hw = kwargs["HW"]
    all_sizes = []
    for config in configs:
        BLOCK_HW_SIZE = config.kwargs["BLOCK_HW_SIZE"]
        if BLOCK_HW_SIZE not in all_sizes:
            all_sizes.append(BLOCK_HW_SIZE)
    for config in configs:
        BLOCK_HW_SIZE = config.kwargs["BLOCK_HW_SIZE"]
        BLOCK_N = config.kwargs["BLOCK_N"]
        if (hw > 2048) and (BLOCK_HW_SIZE >= 2048) and (BLOCK_N <= 4):
            pruned_configs_cached.append(config)
        elif BLOCK_HW_SIZE >= hw:  # was '>': empty result when hw == max size
            # Keep only the tightest tile sizes that still cover hw.
            if not any(hw < size < BLOCK_HW_SIZE for size in all_sizes):
                pruned_configs_cached.append(config)
    # Remove BLOCK_N values that do not evenly divide N.
    for config in pruned_configs_cached:
        block_n = config.kwargs["BLOCK_N"]
        if n % block_n == 0:
            pruned_configs.append(config)
    return pruned_configs
# Backward kernel for the per-channel weight/bias gradients.  Channels are
# statically partitioned across programs (the host launches a fixed grid of
# TOTAL_CORE_NUM programs); each program reduces its channels over the whole
# batch and HW axes.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": bn, "BLOCK_HW_SIZE": size}, num_stages=3, num_warps=1)
        for bn in [1, 4, 8, 16]
        for size in [512, 1024, 2048]
    ],
    prune_configs_by={"early_config_prune": weight_bias_backward_kernel_opt_prune},
    key=["X", "N", "C", "HW", "num_groups"],
)
@triton.jit
def weight_bias_backward_kernel_opt(
    dY,  # upstream gradient, contiguous (N, C, HW)
    X,  # forward input, same layout
    Mean,  # (N, num_groups) saved means
    Rstd,  # (N, num_groups) saved reciprocal stds
    dW,  # output weight gradient (C,); may be None
    dB,  # output bias gradient (C,); may be None
    num_groups,
    group_size,  # channels per group (maps channel -> group via //)
    N,  # batch size; the pruner guarantees N % BLOCK_N == 0
    C,  # channel count
    HW: tl.constexpr,  # spatial elements per channel
    BLOCK_N: tl.constexpr,  # batch rows per tile
    BLOCK_HW_SIZE: tl.constexpr,  # tile width over the HW axis
):
    pid = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    C_SPLIT = tl.cdiv(C, pnum)  # channels owned by each program
    N_SPLIT = tl.cdiv(N, BLOCK_N)  # batch tiles per channel
    hw_iter = tl.cdiv(HW, BLOCK_HW_SIZE)

    n_offset = tl.arange(0, BLOCK_N)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)
    if BLOCK_HW_SIZE >= HW:
        # Whole HW axis in one exact-width tile (unmasked fast path).
        hw_offset = tl.arange(0, HW)

    # This program reduces channels [lb, ub).
    lb = pid * C_SPLIT
    ub = tl.minimum((pid + 1) * C_SPLIT, C)
    for c_start in range(lb, ub):
        if BLOCK_HW_SIZE >= HW:
            # Fast path: unmasked tiles (N % BLOCK_N == 0 per the pruner).
            # First batch tile initializes the accumulators.
            dY_ptr = dY + c_start * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            x_ptr = X + c_start * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            grad_y = tl.load(dY_ptr, cache_modifier=".cg").to(tl.float32)

            x = tl.load(x_ptr, cache_modifier=".cg")
            x_f32 = x.to(tl.float32)

            # Per-sample stats of the group this channel belongs to.
            mean_ptr = Mean + c_start // group_size + n_offset * num_groups
            rstd_ptr = Rstd + c_start // group_size + n_offset * num_groups

            mean = tl.load(mean_ptr, cache_modifier=".cg").to(tl.float32)[:, None]
            rstd = tl.load(rstd_ptr, cache_modifier=".cg").to(tl.float32)[:, None]

            # dB = sum(dY); dW = sum(x_hat * dY) with x_hat = (x - mean) * rstd.
            dB_val = tl.sum(grad_y)
            dW_val = tl.sum((x_f32 - mean) * rstd * grad_y)

            # Remaining batch tiles.
            for n_start in range(1, N_SPLIT):
                new_n_offset = n_start * BLOCK_N + n_offset

                dY_ptr = (
                    dY
                    + c_start * HW
                    + new_n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                )
                x_ptr = (
                    X
                    + c_start * HW
                    + new_n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                )
                grad_y = tl.load(dY_ptr, cache_modifier=".cg").to(tl.float32)

                x = tl.load(x_ptr, cache_modifier=".cg")
                x_f32 = x.to(tl.float32)

                mean_ptr = Mean + c_start // group_size + new_n_offset * num_groups
                rstd_ptr = Rstd + c_start // group_size + new_n_offset * num_groups

                mean = tl.load(mean_ptr, cache_modifier=".cg").to(tl.float32)[:, None]
                rstd = tl.load(rstd_ptr, cache_modifier=".cg").to(tl.float32)[:, None]

                dB_val += tl.sum(grad_y)
                dW_val += tl.sum((x_f32 - mean) * rstd * grad_y)

            if dW is not None:
                tl.store(dW + c_start, dW_val)
            if dB is not None:
                tl.store(dB + c_start, dB_val)
        else:
            # Tiled path: HW is iterated in BLOCK_HW_SIZE chunks with masks.
            # First batch tile, first HW tile initializes the accumulators.
            xy_mask = (n_offset[:, None] < N) & (hw_offset[None, :] < HW)

            dY_ptr = dY + c_start * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            x_ptr = X + c_start * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            grad_y = tl.load(dY_ptr, cache_modifier=".cg").to(tl.float32)

            x = tl.load(x_ptr, cache_modifier=".cg")
            x_f32 = x.to(tl.float32)

            mean_ptr = Mean + c_start // group_size + n_offset * num_groups
            rstd_ptr = Rstd + c_start // group_size + n_offset * num_groups

            mean = tl.load(mean_ptr, cache_modifier=".cg").to(tl.float32)[:, None]
            rstd = tl.load(rstd_ptr, cache_modifier=".cg").to(tl.float32)[:, None]

            dB_val = tl.sum(grad_y)
            dW_val = tl.sum((x_f32 - mean) * rstd * grad_y)

            # Remaining HW tiles of the first batch tile.
            for idx in range(1, hw_iter):
                xy_mask = (n_offset[:, None] < N) & (
                    (idx * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                )
                dY_ptr = (
                    dY
                    + c_start * HW
                    + n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                    + idx * BLOCK_HW_SIZE
                )
                x_ptr = (
                    X
                    + c_start * HW
                    + n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                    + idx * BLOCK_HW_SIZE
                )

                grad_y = tl.load(
                    dY_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg"
                ).to(tl.float32)
                x = tl.load(x_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg")
                x_f32 = x.to(tl.float32)
                dB_val += tl.sum(grad_y)
                # Zero masked lanes so (x - mean) does not leak into the sum.
                x_f32 = tl.where(xy_mask, x_f32 - mean, 0.0)
                dW_val += tl.sum(x_f32 * rstd * grad_y)

            # Remaining batch tiles, each with its own HW-tile loop.
            for n_start in range(1, N_SPLIT):
                new_n_offset = n_start * BLOCK_N + n_offset
                xy_mask = (new_n_offset[:, None] < N) & (hw_offset[None, :] < HW)

                dY_ptr = (
                    dY
                    + c_start * HW
                    + new_n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                )
                x_ptr = (
                    X
                    + c_start * HW
                    + new_n_offset[:, None] * C * HW
                    + hw_offset[None, :]
                )
                grad_y = tl.load(
                    dY_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg"
                ).to(tl.float32)

                x = tl.load(x_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg")
                x_f32 = x.to(tl.float32)

                mean_ptr = Mean + c_start // group_size + new_n_offset * num_groups
                rstd_ptr = Rstd + c_start // group_size + new_n_offset * num_groups

                mean = tl.load(mean_ptr, cache_modifier=".cg").to(tl.float32)[:, None]
                rstd = tl.load(rstd_ptr, cache_modifier=".cg").to(tl.float32)[:, None]

                dB_val += tl.sum(grad_y)
                dW_val += tl.sum((x_f32 - mean) * rstd * grad_y)

                for idx in range(1, hw_iter):
                    xy_mask = (new_n_offset[:, None] < N) & (
                        (idx * BLOCK_HW_SIZE + hw_offset[None, :]) < HW
                    )
                    dY_ptr = (
                        dY
                        + c_start * HW
                        + new_n_offset[:, None] * C * HW
                        + hw_offset[None, :]
                        + idx * BLOCK_HW_SIZE
                    )
                    x_ptr = (
                        X
                        + c_start * HW
                        + new_n_offset[:, None] * C * HW
                        + hw_offset[None, :]
                        + idx * BLOCK_HW_SIZE
                    )

                    grad_y = tl.load(
                        dY_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg"
                    ).to(tl.float32)
                    x = tl.load(x_ptr, mask=xy_mask, other=0.0, cache_modifier=".cg")
                    x_f32 = x.to(tl.float32)
                    dB_val += tl.sum(grad_y)
                    x_f32 = tl.where(xy_mask, x_f32 - mean, 0.0)
                    dW_val += tl.sum(x_f32 * rstd * grad_y)
            if dW is not None:
                tl.store(dW + c_start, dW_val)
            if dB is not None:
                tl.store(dB + c_start, dB_val)
def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
    """GroupNorm forward on Cambricon.

    Args:
        input: tensor treated as (N, C, HxW) after flattening spatial dims.
        weight / bias: optional per-channel affine parameters; may be None.
        N, C, HxW, group: problem sizes.
        eps: epsilon added to the variance.

    Returns:
        (y, mean, rstd): the normalized output and the per-(sample, group)
        statistics needed by the backward pass.
    """
    logger.debug("GEMS_CAMBRICON GROUPNORM FORWARD")
    group_size = C // group

    # The kernel assumes dense layouts.
    input = input.contiguous()
    weight = weight.contiguous() if weight is not None else None
    bias = bias.contiguous() if bias is not None else None

    out = torch.empty_like(input)
    mean = torch.empty((N, group), dtype=input.dtype, device=input.device)
    rstd = torch.empty((N, group), dtype=input.dtype, device=input.device)

    # One program per SPLIT-sized chunk of groups, per sample.
    def grid(meta):
        return (N * triton.cdiv(group, meta["SPLIT"]),)

    with torch_device_fn.device(input.device):
        group_norm_kernel_opt[grid](
            input,
            out,
            weight,
            bias,
            mean,
            rstd,
            group_size,
            C,
            group,
            eps,
            HW=HxW,
            BLOCK_GROUP_SIZE=group_size,
        )
    return out, mean, rstd
def group_norm_backward(
    grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask
):
    """GroupNorm backward on Cambricon.

    Args:
        grad_out: upstream gradient, same shape as ``input``.
        input: forward input tensor, treated as (N, C, HxW).
        mean / rstd: per-(sample, group) statistics saved by the forward pass.
        weight: optional per-channel affine weight (C,); may be None.
        N, C, HxW, group: problem sizes.
        output_mask: three booleans selecting which of (grad_input,
            grad_weight, grad_bias) to compute.

    Returns:
        (grad_input, grad_weight, grad_bias); unrequested entries are None.
    """
    logger.debug("GEMS_CAMBRICON GROUPNORM BACKWARD")

    # Kernels assume dense layouts.
    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    # NOTE(review): the forward wrapper uses C // group; cdiv differs only
    # when C is not divisible by group -- confirm which is intended.
    group_size = triton.cdiv(C, group)

    if output_mask[0]:
        grad_inp = torch.empty_like(input)
        # One program per SPLIT-sized chunk of groups, per sample.
        grid = lambda meta: (N * triton.cdiv(group, meta["SPLIT"]),)
        with torch_device_fn.device(input.device):
            group_norm_backward_kernel_opt[grid](
                grad_out,
                input,
                weight,
                mean,
                rstd,
                group,
                group_size,
                grad_inp,
                C,
                HW=HxW,
                BLOCK_GROUP_SIZE=group_size,
            )
    else:
        grad_inp = None

    if output_mask[1] is False and output_mask[2] is False:
        return grad_inp, None, None

    def _alloc_param_grad(wanted):
        # Per-channel gradient buffer of shape (C,).  Bug fix: the original
        # allocated both buffers with torch.empty_like(weight), which raised
        # when weight was None but a bias gradient was still requested.
        if not wanted:
            return None
        if weight is not None:
            return torch.empty_like(weight)
        return torch.empty(C, dtype=input.dtype, device=input.device)

    weight_grad = _alloc_param_grad(output_mask[1])
    bias_grad = _alloc_param_grad(output_mask[2])
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel_opt[(TOTAL_CORE_NUM, 1, 1)](
            grad_out,
            input,
            mean,
            rstd,
            weight_grad,
            bias_grad,
            group,
            group_size,
            N,
            C,
            HW=HxW,
        )
    return grad_inp, weight_grad, bias_grad