Coverage for src/flag_gems/fused/instance_norm.py: 31%

308 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2import math 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils.type_utils import get_accumulator_dtype 

13 

14logger = logging.getLogger(__name__) 

15Tensor = torch.Tensor 

16 

17 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b strictly below a, i.e. (ceil(a / b) - 1) * b.
    return (tl.cdiv(a, b) - 1) * b

22 

23 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # One program normalizes one (batch, channel) row of N contiguous spatial
    # elements, holding the whole row resident ("persistent") — assumes
    # TILE_N >= N (the launcher picks TILE_N = next_power_of_2(N)).
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M  # guards per-channel loads if pid were out of range
    c_offsets = pid % C  # channel of this row; weight/bias are per channel

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N  # tail lanes beyond N are masked out

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.sum(x) / N  # row mean (masked lanes contribute 0)
    d = x - m  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance
    rstd = tl.math.rsqrt(var + eps)

    # Save the per-(B, C) statistics for the backward pass.
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m) * rstd * w + b
    else:
        out = (x - m) * rstd

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

72 

73 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Variant of the persistent kernel for small N: each program normalizes
    # TILE_M rows at once (a 2D tile) so there is enough work per program.
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M  # last program may own fewer than TILE_M rows
    c_offsets = m_offsets % C  # channel of each row

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask

    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    m = tl.sum(x, axis=1) / N  # per-row mean (masked lanes contribute 0)
    d = x - m[:, None]  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance per row
    rstd = tl.math.rsqrt(var + eps)

    # Save the per-(B, C) statistics for the backward pass.
    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m[:, None]) * rstd[:, None] * w[:, None] + b[:, None]
    else:
        out = (x - m[:, None]) * rstd[:, None]

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)

126 

127 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instance_norm_loop"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # One program normalizes one (batch, channel) row that is too large to
    # keep resident: sweep the row in TILE_N-sized tiles, computing mean and
    # variance with per-lane Welford updates, then a second sweep normalizes.
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M
    c_offsets = pid % C  # channel of this row; weight/bias are per channel

    # Compute mean
    # Per-lane Welford accumulators: lane i accumulates elements
    # i, i + TILE_N, i + 2*TILE_N, ... of the row.
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # mean
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2)
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_steps = tl.cdiv(N, TILE_N)
    # All tiles except the last are full, so no masking is needed here.
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last step
    # The final (possibly partial) tile: masked lanes keep their accumulators.
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Merge the TILE_N per-lane statistics (parallel variance combination:
    # total SS = sum of lane SS plus between-lane mean shift terms).
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m
    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
    else:
        w = 1
        b = 0

    # reverse the order of the second sweep
    # (walk the row back-to-front so the most recently loaded tile from the
    # first sweep is revisited first; evict_first frees lines after use)
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)
    # the first step, masking is needed
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # Remaining tiles are full and aligned, so no masking is needed.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)

218 

219 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_use_running_stats_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    running_mean_ptr,  # pointer to the mean
    running_var_ptr,  # pointer to the var
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Inference-style path: normalize each (batch, channel) row using the
    # per-channel running statistics instead of statistics computed from the
    # input. Persistent layout like instance_norm_persistent_kernel.
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M
    c_offsets = pid % C  # channel of this row; running stats are per channel

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.load(running_mean_ptr + c_offsets, mask=m_mask)
    var = tl.load(running_var_ptr + c_offsets, mask=m_mask)
    rstd = tl.math.rsqrt(var + eps)

    # Also record the statistics actually used, so backward sees them.
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m) * rstd * w + b
    else:
        out = (x - m) * rstd

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

267 

268 

@triton.jit
def update_running_stats_kernel(
    mean_ptr,  # pointer to the per-(B, C) mean, shape [B, C]
    rstd_ptr,  # pointer to the per-(B, C) 1/std, shape [B, C]
    running_mean_ptr,  # [C], updated in place
    running_var_ptr,  # [C], updated in place
    momentum,
    B,
    C,
    N,
    eps,
    BLOCK_BATCH_SIZE: tl.constexpr = 1,
    BLOCK_CHANNEL_SIZE: tl.constexpr = 2048,
):
    # Each program owns BLOCK_CHANNEL_SIZE channels: average the batch
    # statistics over B samples and fold them into the running stats via
    # new = (1 - momentum) * old + momentum * batch_stat.
    cid = tl.program_id(0) * BLOCK_CHANNEL_SIZE + tl.arange(0, BLOCK_CHANNEL_SIZE)
    col_mask = cid < C
    running_mean = tl.load(running_mean_ptr + cid, mask=col_mask).to(tl.float32)
    running_var = tl.load(running_var_ptr + cid, mask=col_mask).to(tl.float32)

    new_mean = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    new_var = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    for b in range(0, B, BLOCK_BATCH_SIZE):
        # BUGFIX: `b` already advances by BLOCK_BATCH_SIZE per iteration, so it
        # must not be scaled by BLOCK_BATCH_SIZE again; the old
        # `b * BLOCK_BATCH_SIZE + ...` skipped batch rows for any
        # BLOCK_BATCH_SIZE > 1.
        bid = b + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        row_mask = bid < B
        mask = row_mask & col_mask[None, :]
        mean = tl.load(mean_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        rstd = tl.load(rstd_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        # Recover the biased variance from rstd = rsqrt(var + eps), i.e.
        # var = 1 / rstd**2 - eps (the old `+ eps` overshot by 2 * eps), then
        # rescale by N / (N - 1): running_var tracks the unbiased variance.
        # Masked lanes are forced to 0 so the out-of-range rstd = 0 does not
        # inject inf into the column sums.
        var = tl.where(mask, (1 / (rstd * rstd) - eps) * N / (N - 1), 0.0)

        new_mean += tl.sum(mean, axis=0)
        new_var += tl.sum(var, axis=0)

    new_running_mean = (1 - momentum) * running_mean + momentum * new_mean / B
    new_running_var = (1 - momentum) * running_var + momentum * new_var / B

    tl.store(running_mean_ptr + cid, new_running_mean, mask=col_mask)
    tl.store(running_var_ptr + cid, new_running_var, mask=col_mask)

312 

313 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instance_norm_backward"),
    key=["M", "N", "C"],
)
@triton.jit
def instance_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dX,
    M,  # M = B * C
    N,
    C,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Input gradient of instance norm, per (batch, channel) row:
    #   dx = rstd * (dy*w - (sum(dy*w) + x_hat * sum(dy*w*x_hat)) / N)
    # Sweep 1 accumulates the two row sums; sweep 2 writes dX.
    pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    c_offsets = pid % C  # channel of each row (for the per-channel weight)
    row_mask = pid < M
    # Advance the base pointers to this program's rows.
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    # other=1.0 for rstd keeps masked rows finite in the arithmetic below.
    mean = tl.load(Mean, mask=row_mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask, other=1.0).to(tl.float32)
    if HAS_WEIGHT_BIAS:
        w = tl.load(W + c_offsets, mask=row_mask).to(tl.float32)
    else:
        w = 1

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # Sweep 1: accumulate sum(dy*w) and sum(dy*w*x_hat) per row.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)  # zero masked lanes before reducing
        x_hat = x * rstd
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    # Reduce the partial tiles to one scalar per row, kept 2D for broadcasting.
    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    # Sweep 2: recompute x_hat/dx_hat per tile and write the gradient.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols, dx, mask=mask)

379 

380 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instance_norm_weight_bias_backward"),
    key=["N", "B", "C"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dW,
    dB,
    M,
    N,
    B,
    C,
    BLOCK_BATCH_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # One program per channel c (grid is (C, 1, 1)):
    #   dW[c] = sum over batch and spatial of dy * x_hat
    #   dB[c] = sum over batch and spatial of dy
    cid = tl.program_id(0)[None]  # lift the scalar pid to a 1-element tensor
    cid = cid[:, None]  # shape (1, 1) for 2D broadcasting below
    dW += cid
    dB += cid
    c_mask = cid < C

    accW = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for b_off in range(0, B, BLOCK_BATCH_SIZE):
        bid = b_off + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        mid = bid * C + cid  # flat (batch, channel) row index into Mean/Rstd
        row_mask = bid < B
        mean = tl.load(Mean + mid, mask=row_mask).to(tl.float32)
        rstd = tl.load(Rstd + mid, mask=row_mask).to(tl.float32)
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(dY + mid * N + cols[None, :], mask).to(tl.float32)
            x = tl.load(X + mid * N + cols[None, :], mask).to(tl.float32)
            x = tl.where(mask, x - mean, 0.0)  # zero masked lanes before reducing
            x_hat = x * rstd
            accW += dy * x_hat
            accB += dy
    # Full reduction of the accumulator tiles to the two per-channel scalars.
    dw = tl.sum(accW)
    db = tl.sum(accB)
    tl.store(dW, dw, mask=c_mask)
    tl.store(dB, db, mask=c_mask)

430 

431 

class InstanceNorm(torch.autograd.Function):
    """Autograd wrapper around the Triton instance-norm kernels.

    Forward dispatches on the spatial size N to one of three forward kernels
    (or the running-stats kernel when ``use_input_stats`` is False), and
    optionally updates the running statistics.  Backward computes the input
    gradient and, when affine, the weight/bias gradients.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        use_input_stats=False,
        momentum=0.1,
        eps=1e-05,
        cudnn_enable=False,
    ):
        """Normalize ``x`` per (batch, channel) over its spatial dims.

        ``x`` must be 3D/4D/5D with layout (B, C, *spatial).  Saves mean and
        rstd (shape [B, C], accumulator precision) for backward.  Returns a
        tensor shaped like ``x``.  ``cudnn_enable`` is accepted for API
        compatibility but unused here.
        """
        logger.debug("GEMS INSTANCENORM FORWARD")
        assert len(x.shape) in [
            3,
            4,
            5,
        ], f"x.shape should be [B, C, N] or [B, C, H, W] or [B, C, H, W, L], but got {x.shape}"
        B, C = x.shape[:2]
        N = math.prod(x.shape[2:])  # spatial elements per (batch, channel) row
        M = x.numel() // N  # number of rows to normalize, M = B * C

        x = x.contiguous()
        weight = weight.contiguous() if weight is not None else None
        bias = bias.contiguous() if bias is not None else None
        y = torch.empty_like(x)

        has_weight_bias = weight is not None
        if has_weight_bias:
            assert weight is not None and bias is not None

        has_running_stats = running_mean is not None
        if has_running_stats:
            assert (
                N > 1
            ), f"Expected more than 1 spatial element when training, got input size {x.shape}"
            assert (
                running_mean is not None and running_var is not None
            ), "running_mean and running_var should not both be None"
            assert (
                running_mean.shape == running_var.shape and running_mean.shape[0] == C
            ), f"running_mean and running_var should have shape as {[C,]}"
            assert (
                running_mean.dtype == running_var.dtype
            ), "running_mean and running_var should have the same dtype"
        if not use_input_stats:
            assert (
                has_running_stats
            ), "Expected running_mean and running_var to be defined when use_input_stats is False"

        # NOTE: when the input is half-precision(either float16 or bfloat16)
        # these statistical data saved for backward is in single precision
        acc_type = get_accumulator_dtype(x.dtype)
        mean = torch.empty(size=(B, C), dtype=acc_type, device=x.device)
        rstd = torch.empty(size=(B, C), dtype=acc_type, device=x.device)

        with torch_device_fn.device(x.device):
            if use_input_stats:
                if N <= 128:
                    # Tiny rows: pack several rows per program (2D tile).
                    TILE_N = triton.next_power_of_2(N)
                    TILE_M = triton.cdiv(1024, TILE_N)
                    grid = (triton.cdiv(M, TILE_M), 1, 1)
                    instance_norm_persistent_kernel_multiline[grid](
                        x,
                        y,
                        weight,
                        bias,
                        mean,
                        rstd,
                        M,
                        N,
                        C,
                        eps,
                        TILE_M,
                        TILE_N,
                        HAS_WEIGHT_BIAS=has_weight_bias,
                    )
                elif N <= 4096:
                    # Medium rows: one persistent program per row.
                    TILE_N = triton.next_power_of_2(N)
                    grid = (M, 1, 1)
                    instance_norm_persistent_kernel[grid](
                        x,
                        y,
                        weight,
                        bias,
                        mean,
                        rstd,
                        M,
                        N,
                        C,
                        eps,
                        TILE_N,
                        HAS_WEIGHT_BIAS=has_weight_bias,
                    )
                else:
                    # Large rows: loop over each row in tiles.
                    grid = (M, 1, 1)
                    instance_norm_loop_kernel[grid](
                        x,
                        y,
                        weight,
                        bias,
                        mean,
                        rstd,
                        M,
                        N,
                        C,
                        eps,
                        HAS_WEIGHT_BIAS=has_weight_bias,
                    )
                if has_running_stats and use_input_stats:  # update running stats
                    grid = lambda meta: (
                        triton.cdiv(C, meta["BLOCK_CHANNEL_SIZE"]),
                        1,
                        1,
                    )
                    update_running_stats_kernel[grid](
                        mean,
                        rstd,
                        running_mean,
                        running_var,
                        momentum,
                        B,
                        C,
                        N,
                        eps,
                    )
            else:  # use running stats instead of input stats
                TILE_N = triton.next_power_of_2(N)
                grid = (M, 1, 1)
                instance_norm_use_running_stats_kernel[grid](
                    x,
                    y,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    TILE_N,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                )

        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.M = M
        ctx.N = N
        ctx.C = C
        ctx.has_weight_bias = has_weight_bias
        return y

    @staticmethod
    def backward(ctx, out_grad):
        """Compute gradients w.r.t. x and, if affine, weight and bias.

        Returns nine values, one per forward parameter (after ctx); the six
        non-differentiable parameters get None.
        """
        logger.debug("GEMS INSTANCENORM BACKWARD")
        out_grad = out_grad.contiguous()
        (x, weight, mean, rstd) = ctx.saved_tensors
        M = ctx.M
        N = ctx.N
        C = ctx.C
        B = M // C

        with torch_device_fn.device(x.device):
            in_grad = torch.empty_like(x)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
            instance_norm_backward_kernel[grid](
                out_grad,
                x,
                weight,
                mean,
                rstd,
                in_grad,
                M,
                N,
                C,
                HAS_WEIGHT_BIAS=ctx.has_weight_bias,
            )

            if ctx.has_weight_bias:
                # One program per channel reduces over batch and spatial dims.
                grid = lambda meta: (C, 1, 1)
                weight_grad = torch.empty_like(weight)
                bias_grad = torch.empty_like(weight)
                weight_bias_backward_kernel[grid](
                    out_grad, x, mean, rstd, weight_grad, bias_grad, M, N, B, C
                )
            else:
                weight_grad = None
                bias_grad = None
        return in_grad, weight_grad, bias_grad, None, None, None, None, None, None

624 

625 

def instance_norm(
    input: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
    cudnn_enable: bool = False,
) -> Tensor:
    r"""Applies Instance Normalization for each channel in each data sample in a
    batch.
    Inputs:
        input: input tensor of shape :math:`(N, C, *)`
        weight: weight tensor of shape :math:`(C)`
        bias: bias tensor of shape :math:`(C)`
        running_mean: running mean tensor of shape :math:`(C)`
        running_var: running variance tensor of shape :math:`(C)`
        use_input_stats: whether to use the mean and variance of the input tensor
        momentum: momentum value for the running mean and variance
        eps: epsilon value for numerical stability
        cudnn_enable: whether to use cudnn for normalization
    Returns:
        output tensor of shape :math:`(N, C, *)`
    """
    # Pass cudnn_enable through (it was silently dropped before): forward
    # accepts it, and InstanceNorm.backward returns nine gradients, so
    # autograd's gradient count only matches when all nine inputs are given.
    return InstanceNorm.apply(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats,
        momentum,
        eps,
        cudnn_enable,
    )