Coverage for src/flag_gems/fused/instance_norm.py: 31%
308 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2import math
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils.type_utils import get_accumulator_dtype
# Module-level logger; kernels and autograd hooks log at DEBUG level.
logger = logging.getLogger(__name__)
# Short alias used in the public `instance_norm` signature below.
Tensor = torch.Tensor
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a,
    # i.e. the largest x < a with x % b == 0.
    # ceil(a / b) * b - b == (ceil(a / b) - 1) * b
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel(
    in_ptr,  # input, treated as (M, N) rows of contiguous spatial elements
    out_ptr,  # output, same layout as input
    weight_ptr,  # per-channel scale, shape (C); only read when HAS_WEIGHT_BIAS
    bias_ptr,  # per-channel shift, shape (C); only read when HAS_WEIGHT_BIAS
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Persistent variant: one program handles one (batch, channel) row and the
    # whole row fits in a single 1d tile (caller guarantees TILE_N >= N).
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M
    c_offsets = pid % C  # channel index of this row; weight/bias are per-channel

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    # Out-of-range lanes were loaded as 0, so the plain sum over the tile
    # divided by N is the row mean.
    m = tl.sum(x) / N
    d = x - m  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance
    rstd = tl.math.rsqrt(var + eps)

    # Save per-row statistics (consumed by the backward kernels).
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m) * rstd * w + b
    else:
        out = (x - m) * rstd

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel_multiline(
    in_ptr,  # input, treated as (M, N) rows of contiguous spatial elements
    out_ptr,  # output, same layout as input
    weight_ptr,  # per-channel scale, shape (C); only read when HAS_WEIGHT_BIAS
    bias_ptr,  # per-channel shift, shape (C); only read when HAS_WEIGHT_BIAS
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Multi-row persistent variant used when N is small: each program
    # normalizes TILE_M rows at once with a 2d (TILE_M, TILE_N) tile
    # (caller guarantees TILE_N >= N).
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M
    c_offsets = m_offsets % C  # per-row channel index; weight/bias are per-channel

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask

    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    # Masked lanes were loaded as 0, so a plain row sum / N is the mean.
    m = tl.sum(x, axis=1) / N
    d = x - m[:, None]  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance
    rstd = tl.math.rsqrt(var + eps)

    # Save per-row statistics (consumed by the backward kernels).
    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m[:, None]) * rstd[:, None] * w[:, None] + b[:, None]
    else:
        out = (x - m[:, None]) * rstd[:, None]

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instance_norm_loop"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_loop_kernel(
    in_ptr,  # input, treated as (M, N) rows of contiguous spatial elements
    out_ptr,  # output, same layout as input
    weight_ptr,  # per-channel scale, shape (C); only read when HAS_WEIGHT_BIAS
    bias_ptr,  # per-channel shift, shape (C); only read when HAS_WEIGHT_BIAS
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Looping variant for large N: one program per row, two sweeps over the
    # row in TILE_N-sized chunks. The first sweep accumulates mean/variance
    # with a per-lane Welford recurrence; the second sweep normalizes.
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M
    c_offsets = pid % C  # channel index; weight/bias are per-channel

    # Compute mean
    # Each of the TILE_N lanes keeps its own running mean m, sum of squared
    # deviations s, and element count cnt (Welford's online update).
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # mean
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2)
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_steps = tl.cdiv(N, TILE_N)
    # All steps but the last cover full tiles, so no bounds mask is needed.
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last step
    # The final (possibly ragged) tile: masked lanes keep their old m/s/cnt.
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Combine the TILE_N per-lane partial statistics into row statistics
    # (parallel-variance merge: pooled mean, then within- plus between-lane
    # squared deviations).
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m
    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
    else:
        w = 1
        b = 0

    # reverse the order of the second sweep
    # Walking backwards from the last tile reuses data still hot in cache
    # from the first sweep.
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)
    # the first step, masking is needed
    # NOTE: this range produces exactly one iteration (start_n == 0),
    # covering the ragged tail tile that starts at prev_multiple.
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # Remaining full tiles, descending; all offsets are in range, no mask.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_use_running_stats_kernel(
    in_ptr,  # input, treated as (M, N) rows of contiguous spatial elements
    out_ptr,  # output, same layout as input
    weight_ptr,  # per-channel scale, shape (C); only read when HAS_WEIGHT_BIAS
    bias_ptr,  # per-channel shift, shape (C); only read when HAS_WEIGHT_BIAS
    running_mean_ptr,  # pointer to the mean
    running_var_ptr,  # pointer to the var
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Inference-style variant: normalize with the provided per-channel
    # running statistics instead of statistics computed from the input.
    # The (broadcast) per-channel mean/rstd are still written to
    # out_mean/out_rstd so the backward pass can consume them.
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M
    c_offsets = pid % C  # channel index; running stats are per-channel

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.load(running_mean_ptr + c_offsets, mask=m_mask)
    var = tl.load(running_var_ptr + c_offsets, mask=m_mask)
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m) * rstd * w + b
    else:
        out = (x - m) * rstd

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)
@triton.jit
def update_running_stats_kernel(
    mean_ptr,  # pointer to the mean, shape (B, C)
    rstd_ptr,  # pointer to the 1/std, shape (B, C)
    running_mean_ptr,  # per-channel running mean, shape (C); updated in place
    running_var_ptr,  # per-channel running var, shape (C); updated in place
    momentum,  # weight of the new batch statistics in the update
    B,
    C,
    N,  # number of spatial elements per (batch, channel) row
    eps,
    BLOCK_BATCH_SIZE: tl.constexpr = 1,
    BLOCK_CHANNEL_SIZE: tl.constexpr = 2048,
):
    # Each program handles BLOCK_CHANNEL_SIZE channels: it averages the
    # per-instance statistics over the batch dimension and folds them in as
    #   running <- (1 - momentum) * running + momentum * batch_average
    cid = tl.program_id(0) * BLOCK_CHANNEL_SIZE + tl.arange(0, BLOCK_CHANNEL_SIZE)
    col_mask = cid < C
    running_mean = tl.load(running_mean_ptr + cid, mask=col_mask).to(tl.float32)
    running_var = tl.load(running_var_ptr + cid, mask=col_mask).to(tl.float32)

    new_mean = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    new_var = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    for b in range(0, B, BLOCK_BATCH_SIZE):
        # BUGFIX: `b` already advances by BLOCK_BATCH_SIZE per iteration, so
        # it must not be scaled again. The old `b * BLOCK_BATCH_SIZE + ...`
        # skipped and repeated batches whenever BLOCK_BATCH_SIZE != 1
        # (harmless only at the default of 1).
        bid = b + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        row_mask = bid < B
        mask = row_mask and col_mask[None, :]
        mean = tl.load(mean_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        # other=1.0 keeps 1 / (rstd * rstd) finite on masked lanes.
        rstd = tl.load(rstd_ptr + bid * C + cid[None, :], mask=mask, other=1.0).to(
            tl.float32
        )
        # rstd = rsqrt(var + eps)  =>  var = 1 / rstd^2 - eps.
        # BUGFIX: the forward kernels compute rstd with eps already added, so
        # eps must be subtracted (not added) when recovering the variance.
        # The N / (N - 1) factor converts the biased (population) variance
        # into the unbiased variance expected in running_var.
        var = (1 / (rstd * rstd) - eps) * N / (N - 1)
        # Zero out masked lanes so they do not contribute to the batch sum.
        var = tl.where(mask, var, 0.0)

        new_mean += tl.sum(mean, axis=0)
        new_var += tl.sum(var, axis=0)

    new_running_mean = (1 - momentum) * running_mean + momentum * new_mean / B
    new_running_var = (1 - momentum) * running_var + momentum * new_var / B

    tl.store(running_mean_ptr + cid, new_running_mean, mask=col_mask)
    tl.store(running_var_ptr + cid, new_running_var, mask=col_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instance_norm_backward"),
    key=["M", "N", "C"],
)
@triton.jit
def instance_norm_backward_kernel(
    dY,  # upstream gradient, (M, N) rows
    X,  # forward input, (M, N) rows
    W,  # per-channel weight, shape (C); only read when HAS_WEIGHT_BIAS
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dX,  # output: input gradient, (M, N) rows
    M,  # M = B * C
    N,
    C,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Input-gradient kernel. Per row, with x_hat = (x - mean) * rstd and
    # dx_hat = dy * w, the standard normalization backward is
    #   dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / N)
    # computed in two sweeps: accumulate the two row sums, then apply.
    pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    c_offsets = pid % C  # channel index of each row
    row_mask = pid < M
    # Advance the base pointers to this block of rows; pid is a column
    # vector, so each pointer below broadcasts to a (row, col) block.
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    mean = tl.load(Mean, mask=row_mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask, other=1.0).to(tl.float32)
    if HAS_WEIGHT_BIAS:
        w = tl.load(W + c_offsets, mask=row_mask).to(tl.float32)
    else:
        w = 1

    # First sweep: accumulate sum(dx_hat) and sum(dx_hat * x_hat) per row.
    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)  # zero masked lanes out of the sums
        x_hat = x * rstd
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    # Second sweep: apply the backward formula element-wise.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols, dx, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instance_norm_weight_bias_backward"),
    key=["N", "B", "C"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,  # upstream gradient, (B * C, N) rows
    X,  # forward input, (B * C, N) rows
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dW,  # output: weight gradient, shape (C)
    dB,  # output: bias gradient, shape (C)
    M,
    N,
    B,
    C,
    BLOCK_BATCH_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Parameter-gradient kernel, one program per channel c:
    #   dW[c] = sum over batch and spatial of dy * x_hat
    #   dB[c] = sum over batch and spatial of dy
    # Lift the scalar program id to a (1, 1) tensor so it broadcasts
    # against the (batch, col) tiles below.
    cid = tl.program_id(0)[None]
    cid = cid[:, None]
    dW += cid
    dB += cid
    c_mask = cid < C

    accW = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for b_off in range(0, B, BLOCK_BATCH_SIZE):
        bid = b_off + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        mid = bid * C + cid  # row index into the (B, C) stats / (M, N) data
        row_mask = bid < B
        mean = tl.load(Mean + mid, mask=row_mask).to(tl.float32)
        rstd = tl.load(Rstd + mid, mask=row_mask).to(tl.float32)
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(dY + mid * N + cols[None, :], mask).to(tl.float32)
            x = tl.load(X + mid * N + cols[None, :], mask).to(tl.float32)
            x = tl.where(mask, x - mean, 0.0)  # zero masked lanes out of sums
            x_hat = x * rstd
            accW += dy * x_hat
            accB += dy
    # Full reduction of the accumulator tiles down to one scalar per channel.
    dw = tl.sum(accW)
    db = tl.sum(accB)
    tl.store(dW, dw, mask=c_mask)
    tl.store(dB, db, mask=c_mask)
class InstanceNorm(torch.autograd.Function):
    """Autograd wrapper dispatching to the Triton instance-norm kernels.

    Forward picks a kernel by spatial size N (multiline tile for tiny N,
    single-tile persistent kernel for moderate N, looping kernel otherwise),
    optionally updating or consuming running statistics; backward computes
    input and, when present, weight/bias gradients.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        use_input_stats=False,
        momentum=0.1,
        eps=1e-05,
        cudnn_enable=False,  # accepted for torch API compatibility; unused here
    ):
        """Normalize each (batch, channel) slice of x over its spatial dims.

        x must be 3d/4d/5d shaped (B, C, *spatial). weight/bias, if given,
        are per-channel (C,) tensors and must be given together.
        running_mean/running_var, if given, are per-channel (C,) tensors;
        they are updated in place when use_input_stats is True and used for
        normalization when it is False. Returns the normalized tensor.
        """
        logger.debug("GEMS INSTANCENORM FORWARD")
        assert len(x.shape) in [
            3,
            4,
            5,
        ], f"x.shape should be [B, C, N] or [B, C, H, W] or [B, C, H, W, L], but got {x.shape}"
        B, C = x.shape[:2]
        N = math.prod(x.shape[2:])  # total spatial elements per (b, c) slice
        M = x.numel() // N  # number of rows to normalize, M = B * C

        # Kernels index with flat row-major offsets, so force contiguity.
        x = x.contiguous()
        weight = weight.contiguous() if weight is not None else None
        bias = bias.contiguous() if bias is not None else None
        y = torch.empty_like(x)

        has_weight_bias = weight is not None
        if has_weight_bias:
            assert weight is not None and bias is not None

        has_running_stats = running_mean is not None
        if has_running_stats:
            assert (
                N > 1
            ), f"Expected more than 1 spatial element when training, got input size {x.shape}"
            assert (
                running_mean is not None and running_var is not None
            ), "running_mean and running_var should not both be None"
            assert (
                running_mean.shape == running_var.shape and running_mean.shape[0] == C
            ), f"running_mean and running_var should have shape as {[C,]}"
            assert (
                running_mean.dtype == running_var.dtype
            ), "running_mean and running_var should have the same dtype"
        if not use_input_stats:
            assert (
                has_running_stats
            ), "Expected running_mean and running_var to be defined when use_input_stats is False"

        # NOTE: when the input is half-precision(either float16 or bfloat16)
        # these statistical data saved for backward is in single precision
        acc_type = get_accumulator_dtype(x.dtype)
        mean = torch.empty(size=(B, C), dtype=acc_type, device=x.device)
        rstd = torch.empty(size=(B, C), dtype=acc_type, device=x.device)

        with torch_device_fn.device(x.device):
            if use_input_stats:
                # Dispatch on spatial size: small rows are batched TILE_M at
                # a time, mid-size rows fit one tile, large rows loop.
                if N <= 128:
                    TILE_N = triton.next_power_of_2(N)
                    TILE_M = triton.cdiv(1024, TILE_N)
                    grid = (triton.cdiv(M, TILE_M), 1, 1)
                    instance_norm_persistent_kernel_multiline[grid](
                        x,
                        y,
                        weight,
                        bias,
                        mean,
                        rstd,
                        M,
                        N,
                        C,
                        eps,
                        TILE_M,
                        TILE_N,
                        HAS_WEIGHT_BIAS=has_weight_bias,
                    )
                elif N <= 4096:
                    TILE_N = triton.next_power_of_2(N)
                    grid = (M, 1, 1)
                    instance_norm_persistent_kernel[grid](
                        x,
                        y,
                        weight,
                        bias,
                        mean,
                        rstd,
                        M,
                        N,
                        C,
                        eps,
                        TILE_N,
                        HAS_WEIGHT_BIAS=has_weight_bias,
                    )
                else:
                    # TILE_N is supplied by the "instance_norm_loop"
                    # autotune configs, so it is not passed here.
                    grid = (M, 1, 1)
                    instance_norm_loop_kernel[grid](
                        x,
                        y,
                        weight,
                        bias,
                        mean,
                        rstd,
                        M,
                        N,
                        C,
                        eps,
                        HAS_WEIGHT_BIAS=has_weight_bias,
                    )
                if has_running_stats and use_input_stats:  # update running stats
                    grid = lambda meta: (
                        triton.cdiv(C, meta["BLOCK_CHANNEL_SIZE"]),
                        1,
                        1,
                    )
                    update_running_stats_kernel[grid](
                        mean,
                        rstd,
                        running_mean,
                        running_var,
                        momentum,
                        B,
                        C,
                        N,
                        eps,
                    )
            else:  # use running stats instead of input stats
                TILE_N = triton.next_power_of_2(N)
                grid = (M, 1, 1)
                instance_norm_use_running_stats_kernel[grid](
                    x,
                    y,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    TILE_N,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                )

        # Stash the (contiguous) input and statistics for backward.
        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.M = M
        ctx.N = N
        ctx.C = C
        ctx.has_weight_bias = has_weight_bias
        return y

    @staticmethod
    def backward(ctx, out_grad):
        """Return gradients matching forward's 9 inputs.

        Computes dx always and (dweight, dbias) when weight/bias were given;
        the remaining six inputs (running stats, flags, scalars) get None.
        """
        logger.debug("GEMS INSTANCENORM BACKWARD")
        out_grad = out_grad.contiguous()
        (x, weight, mean, rstd) = ctx.saved_tensors
        M = ctx.M
        N = ctx.N
        C = ctx.C
        B = M // C

        with torch_device_fn.device(x.device):
            in_grad = torch.empty_like(x)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
            instance_norm_backward_kernel[grid](
                out_grad,
                x,
                weight,
                mean,
                rstd,
                in_grad,
                M,
                N,
                C,
                HAS_WEIGHT_BIAS=ctx.has_weight_bias,
            )

            if ctx.has_weight_bias:
                # One program per channel (meta is unused by this grid).
                grid = lambda meta: (C, 1, 1)
                weight_grad = torch.empty_like(weight)
                bias_grad = torch.empty_like(weight)
                weight_bias_backward_kernel[grid](
                    out_grad, x, mean, rstd, weight_grad, bias_grad, M, N, B, C
                )
            else:
                weight_grad = None
                bias_grad = None
        return in_grad, weight_grad, bias_grad, None, None, None, None, None, None
def instance_norm(
    input: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
    cudnn_enable: bool = False,
) -> Tensor:
    r"""Apply instance normalization to ``input``.

    Every channel of every sample in the batch is normalized independently
    over its spatial elements.

    Inputs:
        input: input tensor of shape :math:`(N, C, *)`
        weight: optional per-channel scale of shape :math:`(C)`
        bias: optional per-channel shift of shape :math:`(C)`
        running_mean: optional running mean of shape :math:`(C)`
        running_var: optional running variance of shape :math:`(C)`
        use_input_stats: normalize with statistics of the input itself
            (updating the running stats when given) rather than with the
            running statistics
        momentum: weight of the fresh statistics in the running-stat update
        eps: small constant added to the variance for numerical stability
        cudnn_enable: accepted for API compatibility; this Triton-backed
            implementation does not use it
    Returns:
        output tensor of shape :math:`(N, C, *)`
    """
    # cudnn_enable is deliberately not forwarded: InstanceNorm.forward
    # ignores it and supplies its own default.
    apply_args = (
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats,
        momentum,
        eps,
    )
    return InstanceNorm.apply(*apply_args)