Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/instance_norm.py: 0% (345 statements).
Report generated by coverage.py v7.6.9 at 2026-03-11 02:28 +0800.
1import logging
2import math
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils.type_utils import get_accumulator_dtype
# Module logger, namespaced under the package-wide "flag_gems" logger.
# NOTE(review): lstrip(".") strips *leading* dots only; a module __name__
# normally has none, so this is likely a no-op kept for symmetry elsewhere.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Short alias used in the type annotations of the public wrapper below.
Tensor = torch.Tensor
@triton.jit
def prev_multiple_of(a, b):
    # the largest x<a that x%b ==0
    # (i.e. ceil(a / b) * b - b: when a is a multiple of b this is a - b,
    # otherwise it is the largest multiple of b strictly below a).
    return tl.cdiv(a, b) * b - b
# One-program-per-row instance norm: each program normalizes one
# (batch, channel) row of length N with a single tile (requires TILE_N >= N).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel(
    in_ptr,  # input, contiguous (M, N)
    out_ptr,  # output, contiguous (M, N)
    weight_ptr,  # per-channel weight (C,); read only when HAS_WEIGHT_BIAS
    bias_ptr,  # per-channel bias (C,); read only when HAS_WEIGHT_BIAS
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M  # guards the per-channel weight/bias loads below
    c_offsets = pid % C  # channel index of this row

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.sum(x) / N
    d = x - m  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance
    rstd = tl.math.rsqrt(var + eps)

    # Save per-row statistics for the backward pass.
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m) * rstd * w + b
    else:
        out = (x - m) * rstd

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)
# Variant of the persistent kernel that processes TILE_M rows per program
# (2-D tile), still requiring TILE_N >= N so each row fits in one tile.
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instancenorm"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel_multiline(
    in_ptr,  # input, contiguous (M, N)
    out_ptr,  # output, contiguous (M, N)
    weight_ptr,  # per-channel weight (C,); read only when HAS_WEIGHT_BIAS
    bias_ptr,  # per-channel bias (C,); read only when HAS_WEIGHT_BIAS
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M
    c_offsets = m_offsets % C  # channel index of each row in the tile

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask

    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    m = tl.sum(x, axis=1) / N  # per-row mean, shape (TILE_M,)
    d = x - m[:, None]  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance per row
    rstd = tl.math.rsqrt(var + eps)

    # Save per-row statistics for the backward pass.
    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m[:, None]) * rstd[:, None] * w[:, None] + b[:, None]
    else:
        out = (x - m[:, None]) * rstd[:, None]

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)
def instance_norm_loop_kernel_heur_tile_n(args):
    """TILE_N heuristic for the loop kernel: pinned to 8192 on this backend.

    The generic ``min(args["N"], 8192)`` fallback that followed the return was
    unreachable dead code and has been removed.
    """
    return 8192
# Instance norm for rows longer than one tile: a single program walks one row
# in TILE_N chunks. Statistics use a per-lane Welford update (running mean,
# M2, and count per lane) that is merged into row scalars after the sweep.
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_loop"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "TILE_N": instance_norm_loop_kernel_heur_tile_n,
    },
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_loop_kernel(
    in_ptr,  # input, contiguous (M, N)
    out_ptr,  # output, contiguous (M, N)
    weight_ptr,  # per-channel weight (C,); read only when HAS_WEIGHT_BIAS
    bias_ptr,  # per-channel bias (C,); read only when HAS_WEIGHT_BIAS
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M  # guards the per-channel weight/bias loads
    c_offsets = pid % C  # channel index of this row

    # Compute mean
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # mean (per lane)
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2) (per lane)
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)  # samples seen per lane
    num_steps = tl.cdiv(N, TILE_N)
    # All full tiles: no bounds masking needed.
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        # Welford update per lane; (step + 1) is the per-lane sample count.
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last step (possibly partial): same update, masked to in-bounds lanes.
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Merge per-lane statistics into row mean / variance (parallel-variance
    # combination: M2_total = sum(M2_lane + cnt * (mean_lane - mean_total)^2)).
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m
    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
    else:
        w = 1
        b = 0

    # reverse the order of the second sweep
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)
    # the first step, masking is needed (this tile reaches the row's tail)
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # Remaining tiles, walking backwards toward the row start; all in-bounds.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)
# Forward kernel used on the kunlunxin (XPU) backend: each program reduces
# XBLOCK rows at once, streaming each row in RBLOCK chunks. Variance is
# computed as E[x^2] - E[x]^2 in a single pass.
@libentry()
@triton.jit(do_not_specialize=["eps"])
def instancenorm_fwd_kernel_xpu(
    X,  # input, contiguous (M, N)
    Y,  # output, contiguous (M, N)
    W,  # per-channel weight (C,); read only when HAS_WEIGHT_BIAS
    B,  # per-channel bias (C,); read only when HAS_WEIGHT_BIAS
    MEAN,  # per-row mean output, (M,)
    RSTRD,  # per-row 1/std output, (M,)
    M: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    eps: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
    XBLOCK: tl.constexpr,
    RBLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    xoffset = pid * XBLOCK
    _xindex = xoffset + tl.arange(0, XBLOCK)
    xindex = _xindex[:, None]  # row indices, shape (XBLOCK, 1)
    xmask = xindex < M
    rbase = tl.arange(0, RBLOCK)[None, :]  # column offsets, shape (1, RBLOCK)
    _mean = tl.full([XBLOCK, RBLOCK], 0, tl.float32)  # running sum(x)
    _var = tl.full([XBLOCK, RBLOCK], 0, tl.float32)  # running sum(x^2)

    # First sweep: accumulate sums of x and x^2 over the row.
    for roffset in range(0, N, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < N
        x = tl.load(X + (rindex + (N * xindex)), rmask & xmask, other=0.0).to(
            tl.float32
        )
        _mean = _mean + tl.broadcast_to(x, [XBLOCK, RBLOCK])
        _var = _var + tl.broadcast_to(x * x, [XBLOCK, RBLOCK])

    mean = tl.sum(_mean, 1)[:, None] / N
    var = tl.sum(_var, 1)[:, None] / N
    # Biased variance via E[x^2] - E[x]^2 (single-pass; less numerically
    # stable than Welford when |mean| is large relative to the spread).
    var_mean = var - mean * mean
    rstd = 1 / tl.sqrt(var_mean + eps)

    tl.store(MEAN + xindex, mean, xmask)
    tl.store(RSTRD + xindex, rstd, xmask)

    cindex = xindex % C  # channel index of each row
    # Second sweep: normalize and apply the affine transform.
    for roffset in range(0, N, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < N
        x = tl.load(X + (rindex + (N * xindex)), rmask & xmask, other=0.0).to(
            tl.float32
        )
        # NOTE(review): w/b are loop-invariant; they could be loaded once
        # before the loop (behavior unchanged).
        if HAS_WEIGHT_BIAS:
            w = tl.load(W + cindex, xmask)
            b = tl.load(B + cindex, xmask)
        else:
            w = 1
            b = 0
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        tl.store(Y + (rindex + (N * xindex)), y, rmask & xmask)
def instance_norm_use_running_stats_kernel_heur_tile_n(args):
    """TILE_N heuristic: pinned to 8192 on this backend.

    The generic ``min(args["N"], 8192)`` fallback that followed the return was
    unreachable dead code and has been removed.
    """
    return 8192
# Forward path when use_input_stats is False: normalize with the provided
# per-channel running mean / var instead of statistics computed from the input.
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instancenorm"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def instance_norm_use_running_stats_kernel(
    in_ptr,  # input, contiguous (M, N)
    out_ptr,  # output, contiguous (M, N)
    weight_ptr,  # per-channel weight (C,); read only when HAS_WEIGHT_BIAS
    bias_ptr,  # per-channel bias (C,); read only when HAS_WEIGHT_BIAS
    running_mean_ptr,  # pointer to the mean
    running_var_ptr,  # pointer to the var
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M
    c_offsets = pid % C  # channel of this row; running stats are per channel

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.load(running_mean_ptr + c_offsets, mask=m_mask).to(tl.float32)
    var = tl.load(running_var_ptr + c_offsets, mask=m_mask).to(tl.float32)
    rstd = tl.math.rsqrt(var + eps)

    # Echo the (running) statistics into the per-row buffers saved for backward.
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask).to(tl.float32)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask).to(tl.float32)
        out = (x - m) * rstd * w + b
    else:
        out = (x - m) * rstd

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)
@triton.jit
def update_running_stats_kernel(
    mean_ptr,  # per-(batch, channel) means, (B, C)
    rstd_ptr,  # per-(batch, channel) 1/std, (B, C)
    running_mean_ptr,  # (C,), updated in place
    running_var_ptr,  # (C,), updated in place
    momentum,
    B,
    C,
    N,
    eps,
    BLOCK_BATCH_SIZE: tl.constexpr = 1,
    BLOCK_CHANNEL_SIZE: tl.constexpr = 2048,
):
    # Each program owns one tile of channels; the batch axis is reduced in a
    # loop. New running stats: (1 - momentum) * old + momentum * batch_average.
    cid = tl.program_id(0) * BLOCK_CHANNEL_SIZE + tl.arange(0, BLOCK_CHANNEL_SIZE)
    col_mask = cid < C
    running_mean = tl.load(running_mean_ptr + cid, mask=col_mask).to(tl.float32)
    running_var = tl.load(running_var_ptr + cid, mask=col_mask).to(tl.float32)

    new_mean = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    new_var = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    for b in range(0, B, BLOCK_BATCH_SIZE):
        # BUGFIX: `b` already advances by BLOCK_BATCH_SIZE per iteration; the
        # previous `b * BLOCK_BATCH_SIZE + ...` double-scaled the offset and
        # skipped batch rows whenever BLOCK_BATCH_SIZE > 1.
        bid = b + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        row_mask = bid < B
        mask = row_mask and col_mask[None, :]
        mean = tl.load(mean_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        rstd = tl.load(rstd_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        # The forward pass computed rstd = rsqrt(var + eps), so the biased
        # variance is 1 / rstd**2 - eps. BUGFIX: the old code *added* eps,
        # overshooting the true variance by 2 * eps. Scale by N / (N - 1) to
        # update running_var with the unbiased estimate.
        var = (
            (1 / (rstd * rstd) - eps) * N / (N - 1)
        )  # NOTE: use unbiased var to update running_var
        # Masked lanes loaded rstd = 0 and would contribute inf via 1/0;
        # zero them out before the reduction.
        var = tl.where(mask, var, 0.0)

        new_mean += tl.sum(mean, axis=0)
        new_var += tl.sum(var, axis=0)

    new_running_mean = (1 - momentum) * running_mean + momentum * new_mean / B
    new_running_var = (1 - momentum) * running_var + momentum * new_var / B

    tl.store(running_mean_ptr + cid, new_running_mean, mask=col_mask)
    tl.store(running_var_ptr + cid, new_running_var, mask=col_mask)
def instance_norm_backward_kernel_heur_block_row_size(args):
    """Row-tile size for the backward kernel: pinned to 1 on this backend.

    The cluster-based fallback (``next_power_of_2(cdiv(args["M"], 12))``) that
    followed the return was unreachable dead code and has been removed.
    """
    return 1
def instance_norm_backward_kernel_heur_block_col_size(args):
    """Column-tile size: next power of two covering N, capped at 8192."""
    import builtins

    n_pow2 = triton.next_power_of_2(args["N"])
    return builtins.min(n_pow2, 8192)
# Input gradient of instance norm. With x_hat = (x - mean) * rstd and
# dx_hat = dy * w, the gradient is
#   dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / N)
# computed in two sweeps: first accumulate the two row sums, then apply.
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_backward"),
#     key=["M", "N", "C"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": instance_norm_backward_kernel_heur_block_row_size,
        "BLOCK_COL_SIZE": instance_norm_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def instance_norm_backward_kernel(
    dY,  # upstream gradient, contiguous (M, N)
    X,  # saved forward input, contiguous (M, N)
    W,  # per-channel weight (C,); read only when HAS_WEIGHT_BIAS
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dX,  # output: input gradient, contiguous (M, N)
    M,  # M = B * C
    N,
    C,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Row indices handled by this program, shape (BLOCK_ROW_SIZE, 1).
    pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    c_offsets = pid % C  # channel of each row
    row_mask = pid < M
    # Advance the base pointers to this program's rows.
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    mean = tl.load(Mean, mask=row_mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask, other=1.0).to(tl.float32)
    if HAS_WEIGHT_BIAS:
        w = tl.load(W + c_offsets, mask=row_mask).to(tl.float32)
    else:
        w = 1

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # First sweep: accumulate sum(dx_hat) and sum(dx_hat * x_hat) per row.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    # Second sweep: apply the gradient formula element-wise.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols, dx, mask=mask)
def weight_bias_backward_kernel_heur_block_batch_size(args):
    """Batch-tile size for the weight/bias backward kernel: pinned to 1.

    The fallback that followed the return (a copy-pasted column-size style
    ``min(next_power_of_2(args["N"]), 8192)``) was unreachable dead code and
    has been removed.
    """
    return 1
def weight_bias_backward_kernel_heur_block_col_size(args):
    """Column-tile size: next power of two of ceil(C / 12) (12 = cluster_num)."""
    channels_per_cluster = triton.cdiv(args["C"], 12)
    return triton.next_power_of_2(channels_per_cluster)
# Per-channel parameter gradients: dW = sum(dy * x_hat), dB = sum(dy),
# reduced over batch and spatial dimensions. Launched with grid = (C, 1, 1),
# i.e. one program per channel.
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_weight_bias_backward"),
#     key=["N", "B", "C"],
# )
@triton.heuristics(
    values={
        "BLOCK_BATCH_SIZE": weight_bias_backward_kernel_heur_block_batch_size,
        "BLOCK_COL_SIZE": weight_bias_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def weight_bias_backward_kernel(
    dY,  # upstream gradient, contiguous (M, N)
    X,  # saved forward input, contiguous (M, N)
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dW,  # output: weight gradient, (C,)
    dB,  # output: bias gradient, (C,)
    M,
    N,
    B,
    C,
    BLOCK_BATCH_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # NOTE(review): tl.program_id(0) is a scalar; indexing it with [:, None]
    # looks suspect — confirm this yields a (1, 1) channel index on this
    # backend's Triton (this kernel shows 0% coverage, so it may be untested).
    cid = tl.program_id(0)[:, None]
    dW += cid
    dB += cid
    c_mask = cid < C

    accW = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for b_off in range(0, B, BLOCK_BATCH_SIZE):
        bid = b_off + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        mid = bid * C + cid  # flattened (batch, channel) row index into (M, N)
        row_mask = bid < B
        mean = tl.load(Mean + mid, mask=row_mask).to(tl.float32)
        rstd = tl.load(Rstd + mid, mask=row_mask).to(tl.float32)
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(dY + mid * N + cols[None, :], mask).to(tl.float32)
            x = tl.load(X + mid * N + cols[None, :], mask).to(tl.float32)
            x = tl.where(mask, x - mean, 0.0)
            x_hat = x * rstd
            accW += dy * x_hat
            accB += dy
    dw = tl.sum(accW)
    db = tl.sum(accB)
    tl.store(dW, dw, mask=c_mask)
    tl.store(dB, db, mask=c_mask)
class InstanceNorm(torch.autograd.Function):
    """Autograd wrapper around the Triton instance-norm kernels.

    ``forward`` normalizes each (batch, channel) row of a contiguous
    (B, C, *) input and stashes the per-row mean / rstd for ``backward``.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        use_input_stats=False,
        momentum=0.1,
        eps=1e-05,
        cudnn_enable=False,  # accepted for API parity; not referenced below
    ):
        logger.debug("GEMS INSTANCENORM FORWARD")
        assert len(x.shape) in [
            3,
            4,
            5,
        ], f"x.shape should be [B, C, N] or [B, C, H, W] or [B, C, H, W, L], but got {x.shape}"
        B, C = x.shape[:2]
        # N is the flattened spatial size; M = B * C is the number of rows.
        N = math.prod(x.shape[2:])
        M = x.numel() // N

        # All kernels assume contiguous row-major layout.
        x = x.contiguous()
        weight = weight.contiguous() if weight is not None else None
        bias = bias.contiguous() if bias is not None else None
        y = torch.empty_like(x)

        has_weight_bias = weight is not None
        if has_weight_bias:
            assert weight is not None and bias is not None

        has_running_stats = running_mean is not None
        if has_running_stats:
            assert (
                N > 1
            ), f"Expected more than 1 spatial element when training, got input size {x.shape}"
            assert (
                running_mean is not None and running_var is not None
            ), "running_mean and running_var should not both be None"
            assert (
                running_mean.shape == running_var.shape and running_mean.shape[0] == C
            ), f"running_mean and running_var should have shape as {[C,]}"
            assert (
                running_mean.dtype == running_var.dtype
            ), "running_mean and running_var should have the same dtype"
        if not use_input_stats:
            assert (
                has_running_stats
            ), "Expected running_mean and running_var to be defined when use_input_stats is False"

        # NOTE: when the input is half-precision(either float16 or bfloat16)
        # these statistical data saved for backward is in single precision
        acc_type = get_accumulator_dtype(x.dtype)
        mean = torch.empty(size=(B, C), dtype=acc_type, device=x.device)
        rstd = torch.empty(size=(B, C), dtype=acc_type, device=x.device)

        with torch_device_fn.device(x.device):
            if use_input_stats:
                # Statistics are computed from the input itself.
                # Presumably 12 = cluster count on this backend — TODO confirm.
                grid = (12, 1, 1)
                instancenorm_fwd_kernel_xpu[grid](
                    x,
                    y,
                    weight,
                    bias,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                    XBLOCK=triton.next_power_of_2(triton.cdiv(M, 12)),
                    RBLOCK=8192,
                    isCloseUnrollControl=True,
                    buffer_size_limit=512,
                )
                if has_running_stats and use_input_stats:  # update running stats
                    grid = lambda meta: (
                        triton.cdiv(C, meta["BLOCK_CHANNEL_SIZE"]),
                        1,
                        1,
                    )
                    update_running_stats_kernel[grid](
                        mean,
                        rstd,
                        running_mean,
                        running_var,
                        momentum,
                        B,
                        C,
                        N,
                        eps,
                        isCloseCoreTiling=True,
                        isCloseVectorization=True,
                        isCloseUnrollControl=True,
                    )
            else:  # use running stats instead of input stats
                TILE_N = triton.next_power_of_2(N)
                grid = (M, 1, 1)
                instance_norm_use_running_stats_kernel[grid](
                    x,
                    y,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    TILE_N,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                    isCloseUnrollControl=True,
                )

        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.M = M
        ctx.N = N
        ctx.C = C
        ctx.has_weight_bias = has_weight_bias
        return y

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS INSTANCENORM BACKWARD")
        out_grad = out_grad.contiguous()
        (x, weight, mean, rstd) = ctx.saved_tensors
        M = ctx.M
        N = ctx.N
        C = ctx.C
        B = M // C

        with torch_device_fn.device(x.device):
            in_grad = torch.empty_like(x)
            # One program per BLOCK_ROW_SIZE rows of the (M, N) view.
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)

            instance_norm_backward_kernel[grid](
                out_grad,
                x,
                weight,
                mean,
                rstd,
                in_grad,
                M,
                N,
                C,
                HAS_WEIGHT_BIAS=ctx.has_weight_bias,
                isCloseCoreTiling=True,
            )

            if ctx.has_weight_bias:
                # One program per channel reduces dW / dB over batch and space.
                grid = lambda meta: (C, 1, 1)
                weight_grad = torch.empty_like(weight)
                bias_grad = torch.empty_like(weight)
                weight_bias_backward_kernel[grid](
                    out_grad,
                    x,
                    mean,
                    rstd,
                    weight_grad,
                    bias_grad,
                    M,
                    N,
                    B,
                    C,
                )
            else:
                weight_grad = None
                bias_grad = None
        # One gradient slot per forward argument (ctx excluded).
        return in_grad, weight_grad, bias_grad, None, None, None, None, None, None
def instance_norm(
    input: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
    cudnn_enable: bool = False,
) -> Tensor:
    r"""Apply instance normalization per channel of every sample in a batch.

    Inputs:
        input: input tensor of shape :math:`(N, C, *)`
        weight: weight tensor of shape :math:`(C)`
        bias: bias tensor of shape :math:`(C)`
        running_mean: running mean tensor of shape :math:`(C)`
        running_var: running variance tensor of shape :math:`(C)`
        use_input_stats: whether to use the mean and variance of the input tensor
        momentum: momentum value for the running mean and variance
        eps: epsilon value for numerical stability
        cudnn_enable: whether to use cudnn for normalization
    Returns:
        output tensor of shape :math:`(N, C, *)`
    """
    # ``cudnn_enable`` is deliberately not forwarded; the autograd Function
    # falls back to its own default for it.
    fn_args = (
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats,
        momentum,
        eps,
    )
    return InstanceNorm.apply(*fn_args)