Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/instance_norm.py: 0%

345 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2import math 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils.type_utils import get_accumulator_dtype 

13 

14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

15Tensor = torch.Tensor 

16 

17 

@triton.jit
def prev_multiple_of(a, b):
    # Step back one stride of b from the ceiling multiple of b: for positive
    # a and b this yields the greatest multiple of b that is strictly < a.
    next_multiple = tl.cdiv(a, b) * b
    return next_multiple - b

22 

23 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Single-pass instance norm: one program normalizes one (batch, channel)
    row of length N. Assumes the whole row fits in one tile (TILE_N >= N),
    so mean/var need only one load sweep.
    """
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M  # guards weight/bias loads if the grid overshoots M
    c_offsets = pid % C  # channel index of this row (rows laid out B-major)

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N  # tail mask for the (single) spatial tile

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.sum(x) / N  # mean over the spatial dimension
    d = x - m  # deviation
    s = tl.where(mask, d * d, 0)  # zero masked lanes so they don't pollute the sum
    sum_square = tl.sum(s)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance
    rstd = tl.math.rsqrt(var + eps)

    # Save per-(batch, channel) statistics for the backward pass.
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        # Per-channel affine parameters.
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m) * rstd * w + b
    else:
        out = (x - m) * rstd

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

72 

73 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instancenorm"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Persistent variant that processes TILE_M rows per program with a 2D
    (TILE_M x TILE_N) tile; like the 1D version it assumes TILE_N >= N so
    per-row statistics are computed in a single pass.
    """
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M
    c_offsets = m_offsets % C  # channel index of each row

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask  # combined row/column validity

    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    m = tl.sum(x, axis=1) / N  # per-row mean
    d = x - m[:, None]  # deviation
    s = tl.where(mask, d * d, 0)  # zero masked lanes before reducing
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N  # biased (population) variance
    rstd = tl.math.rsqrt(var + eps)

    # Save per-(batch, channel) statistics for the backward pass.
    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    if HAS_WEIGHT_BIAS:
        # Per-channel affine parameters, broadcast across the row dimension.
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
        out = (x - m[:, None]) * rstd[:, None] * w[:, None] + b[:, None]
    else:
        out = (x - m[:, None]) * rstd[:, None]

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)

126 

127 

def instance_norm_loop_kernel_heur_tile_n(args):
    """Column tile size (TILE_N) heuristic for ``instance_norm_loop_kernel``.

    A fixed 8192-element tile is used on this backend. The previous body
    contained an unreachable ``min(args["N"], 8192)`` fallback behind the
    early return; that dead code has been removed.
    """
    return 8192

133 

134 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_loop"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "TILE_N": instance_norm_loop_kernel_heur_tile_n,
    },
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Looping instance norm for rows longer than one tile (N > TILE_N).

    Sweep 1 accumulates per-lane Welford statistics (mean, M2, count) over
    TILE_N parallel lanes, then merges them into the row statistics.
    Sweep 2 reloads the input, normalizes and writes the output.
    """
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M  # guards weight/bias loads if the grid overshoots M
    c_offsets = pid % C  # channel index of this row

    # Compute mean
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # mean
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2)
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_steps = tl.cdiv(N, TILE_N)
    # Full tiles need no masking; the (possibly partial) last tile is
    # handled separately below.
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        # Welford update: each lane has seen (step + 1) elements so far.
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last step
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        # Masked lanes keep their previous statistics unchanged.
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Merge the TILE_N per-lane statistics into the row-level mean/variance
    # (parallel-variance combination, weighted by each lane's count).
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m
    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
    else:
        w = 1
        b = 0

    # reverse the order of the second sweep
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)
    # the first step, masking is needed
    # NOTE: this is a deliberate single-iteration loop (range(0, TILE_N,
    # TILE_N)) covering the tail tile [prev_multiple, prev_multiple + TILE_N).
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # Remaining full tiles, walked backwards from prev_multiple; no masking
    # is needed because every offset here is < prev_multiple <= N.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)

230 

231 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def instancenorm_fwd_kernel_xpu(
    X,
    Y,
    W,
    B,
    MEAN,
    RSTRD,  # output 1/std buffer (spelled 'RSTD' elsewhere in this file)
    M: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    eps: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
    XBLOCK: tl.constexpr,
    RBLOCK: tl.constexpr,
):
    """XPU-targeted forward: each program reduces XBLOCK rows over their N
    spatial elements in RBLOCK-sized chunks.

    Variance uses the two-accumulator identity E[x^2] - E[x]^2 rather than
    Welford. NOTE(review): this identity can lose precision when |mean| is
    large relative to the variance — acceptable here only if inputs are
    well-scaled; confirm against the backend's accuracy tests.
    """
    pid = tl.program_id(0)
    xoffset = pid * XBLOCK
    _xindex = xoffset + tl.arange(0, XBLOCK)
    xindex = _xindex[:, None]  # row indices as a column vector
    xmask = xindex < M
    rbase = tl.arange(0, RBLOCK)[None, :]
    _mean = tl.full([XBLOCK, RBLOCK], 0, tl.float32)  # running sum(x)
    _var = tl.full([XBLOCK, RBLOCK], 0, tl.float32)  # running sum(x^2)

    # First sweep: accumulate sums per lane, chunk by chunk.
    for roffset in range(0, N, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < N
        x = tl.load(X + (rindex + (N * xindex)), rmask & xmask, other=0.0).to(
            tl.float32
        )
        _mean = _mean + tl.broadcast_to(x, [XBLOCK, RBLOCK])
        _var = _var + tl.broadcast_to(x * x, [XBLOCK, RBLOCK])

    mean = tl.sum(_mean, 1)[:, None] / N
    var = tl.sum(_var, 1)[:, None] / N
    var_mean = var - mean * mean  # biased variance via E[x^2] - E[x]^2
    rstd = 1 / tl.sqrt(var_mean + eps)

    # Save per-row statistics for the backward pass.
    tl.store(MEAN + xindex, mean, xmask)
    tl.store(RSTRD + xindex, rstd, xmask)

    cindex = xindex % C  # channel index of each row
    # Second sweep: reload, normalize, apply the optional affine, store.
    for roffset in range(0, N, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < N
        x = tl.load(X + (rindex + (N * xindex)), rmask & xmask, other=0.0).to(
            tl.float32
        )
        if HAS_WEIGHT_BIAS:
            # Loop-invariant loads; presumably hoisted by the compiler.
            w = tl.load(W + cindex, xmask)
            b = tl.load(B + cindex, xmask)
        else:
            w = 1
            b = 0
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        tl.store(Y + (rindex + (N * xindex)), y, rmask & xmask)

291 

292 

def instance_norm_use_running_stats_kernel_heur_tile_n(args):
    """Column tile size (TILE_N) heuristic for the running-stats kernel.

    A fixed 8192-element tile is used on this backend. The previous body
    contained an unreachable ``min(args["N"], 8192)`` fallback behind the
    early return; that dead code has been removed.
    """
    return 8192

298 

299 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instancenorm"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def instance_norm_use_running_stats_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    running_mean_ptr,  # pointer to the mean
    running_var_ptr,  # pointer to the var
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Eval-mode instance norm: normalizes each (batch, channel) row using
    the per-channel *running* mean/var instead of statistics computed from
    the input. The statistics actually used are also written to
    out_mean/out_rstd so backward sees the same values. Assumes TILE_N >= N.
    """
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M  # guards per-channel loads if the grid overshoots M
    c_offsets = pid % C  # channel index of this row

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    # Per-channel running statistics (shared by all batches of a channel).
    m = tl.load(running_mean_ptr + c_offsets, mask=m_mask).to(tl.float32)
    var = tl.load(running_var_ptr + c_offsets, mask=m_mask).to(tl.float32)
    rstd = tl.math.rsqrt(var + eps)

    # Record the statistics used, for the backward pass.
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask).to(tl.float32)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask).to(tl.float32)
        out = (x - m) * rstd * w + b
    else:
        out = (x - m) * rstd

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

347 

348 

@triton.jit
def update_running_stats_kernel(
    mean_ptr,  # [B, C] per-instance means from the forward kernel
    rstd_ptr,  # [B, C] per-instance 1/std from the forward kernel
    running_mean_ptr,  # [C] running mean, updated in place
    running_var_ptr,  # [C] running (unbiased) variance, updated in place
    momentum,
    B,
    C,
    N,
    eps,
    BLOCK_BATCH_SIZE: tl.constexpr = 1,
    BLOCK_CHANNEL_SIZE: tl.constexpr = 2048,
):
    """Fold per-instance statistics into the running buffers:

        running <- (1 - momentum) * running + momentum * batch_average

    The variance folded in is the unbiased estimate, recovered from rstd
    (caller asserts N > 1 before launching, so N - 1 is safe).
    """
    cid = tl.program_id(0) * BLOCK_CHANNEL_SIZE + tl.arange(0, BLOCK_CHANNEL_SIZE)
    col_mask = cid < C
    running_mean = tl.load(running_mean_ptr + cid, mask=col_mask).to(tl.float32)
    running_var = tl.load(running_var_ptr + cid, mask=col_mask).to(tl.float32)

    new_mean = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    new_var = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    for b in range(0, B, BLOCK_BATCH_SIZE):
        # BUGFIX: `b` already advances in steps of BLOCK_BATCH_SIZE, so the
        # row index is `b + arange`, not `b * BLOCK_BATCH_SIZE + arange`.
        # The old form was only correct for the default BLOCK_BATCH_SIZE == 1.
        bid = b + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        row_mask = bid < B
        mask = row_mask and col_mask[None, :]
        mean = tl.load(mean_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        rstd = tl.load(rstd_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        # rstd = rsqrt(var + eps)  =>  biased var = 1/rstd^2 - eps.
        # BUGFIX: the biased variance is recovered by SUBTRACTING eps (the
        # old `+ eps` overstated it by 2*eps); scale by N/(N-1) for the
        # unbiased estimate used to update running_var.
        unbiased_var = (1 / (rstd * rstd) - eps) * N / (N - 1)
        # Masked-off rows loaded rstd = 0 (-> inf above); zero them so they
        # do not poison the batch sum when BLOCK_BATCH_SIZE does not divide B.
        var = tl.where(mask, unbiased_var, 0.0)

        new_mean += tl.sum(mean, axis=0)
        new_var += tl.sum(var, axis=0)

    new_running_mean = (1 - momentum) * running_mean + momentum * new_mean / B
    new_running_var = (1 - momentum) * running_var + momentum * new_var / B

    tl.store(running_mean_ptr + cid, new_running_mean, mask=col_mask)
    tl.store(running_var_ptr + cid, new_running_var, mask=col_mask)

392 

393 

def instance_norm_backward_kernel_heur_block_row_size(args):
    """Row tile size (BLOCK_ROW_SIZE) heuristic for the backward kernel.

    One row per program on this backend. The previous body contained an
    unreachable cluster-based heuristic
    (``next_power_of_2(cdiv(M, 12))``) behind the early return; that dead
    code has been removed.
    """
    return 1

397 

398 

def instance_norm_backward_kernel_heur_block_col_size(args):
    """Column tile size for the backward kernel: the next power of two
    covering N, capped at 8192."""
    tile = triton.next_power_of_2(args["N"])
    return min(tile, 8192)

403 

404 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_backward"),
#     key=["M", "N", "C"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": instance_norm_backward_kernel_heur_block_row_size,
        "BLOCK_COL_SIZE": instance_norm_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def instance_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dX,
    M,  # M = B * C
    N,
    C,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Input gradient for instance norm, per (batch, channel) row:

        dx = rstd * (dy*w - (sum(dy*w) + x_hat * sum(dy*w*x_hat)) / N)

    Two column sweeps: the first accumulates the two row-wise sums, the
    second applies the formula and stores dX.
    """
    pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    c_offsets = pid % C  # channel index of each row
    row_mask = pid < M
    # Advance base pointers to this program's rows (pid is a column vector,
    # so these become 2D pointer blocks).
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    mean = tl.load(Mean, mask=row_mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask, other=1.0).to(tl.float32)
    if HAS_WEIGHT_BIAS:
        w = tl.load(W + c_offsets, mask=row_mask).to(tl.float32)
    else:
        w = 1

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # Sweep 1: accumulate sum(dy*w) and sum(dy*w*x_hat) per row.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)  # zero masked lanes before reducing
        x_hat = x * rstd
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]  # sum(dy*w) per row
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]  # sum(dy*w*x_hat) per row

    # Sweep 2: recompute x_hat chunk by chunk and emit dX.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask).to(tl.float32)
        x = tl.load(X + cols[None, :], mask).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        # dX is already a 2D pointer block; adding the 1D `cols` broadcasts.
        tl.store(dX + cols, dx, mask=mask)

476 

477 

def weight_bias_backward_kernel_heur_block_batch_size(args):
    """Batch tile size (BLOCK_BATCH_SIZE) heuristic for the weight/bias
    backward kernel.

    One batch row per step on this backend. The previous body contained
    unreachable dead code behind the early return — and that dead expression
    (``min(next_power_of_2(args["N"]), 8192)``) was an N-based tile size,
    not a batch size, i.e. an apparent copy-paste; it has been removed.
    """
    return 1

483 

484 

def weight_bias_backward_kernel_heur_block_col_size(args):
    """Column tile size for the weight/bias backward kernel: channels split
    across the 12 clusters, rounded up to a power of two.

    NOTE(review): this sizes the tile from C although the kernel tiles its
    inner loop over N — confirm this is intentional for this backend.
    """
    per_cluster = triton.cdiv(args["C"], 12)  # 12 = cluster_num
    return triton.next_power_of_2(per_cluster)

487 

488 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_weight_bias_backward"),
#     key=["N", "B", "C"],
# )
@triton.heuristics(
    values={
        "BLOCK_BATCH_SIZE": weight_bias_backward_kernel_heur_block_batch_size,
        "BLOCK_COL_SIZE": weight_bias_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dW,
    dB,
    M,
    N,
    B,
    C,
    BLOCK_BATCH_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Per-channel parameter gradients, one program per channel (grid = C):

        dW[c] = sum over batches and spatial of dy * x_hat
        dB[c] = sum over batches and spatial of dy
    """
    # NOTE(review): indexing a scalar tl.program_id with [:, None] relies on
    # this backend's Triton accepting 0-d indexing — confirm on upstream
    # Triton before porting.
    cid = tl.program_id(0)[:, None]
    dW += cid
    dB += cid
    c_mask = cid < C

    accW = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # Outer loop over batches, inner loop over the spatial dimension.
    for b_off in range(0, B, BLOCK_BATCH_SIZE):
        bid = b_off + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        mid = bid * C + cid  # flattened (batch, channel) row index
        row_mask = bid < B
        mean = tl.load(Mean + mid, mask=row_mask).to(tl.float32)
        rstd = tl.load(Rstd + mid, mask=row_mask).to(tl.float32)
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            dy = tl.load(dY + mid * N + cols[None, :], mask).to(tl.float32)
            x = tl.load(X + mid * N + cols[None, :], mask).to(tl.float32)
            x = tl.where(mask, x - mean, 0.0)  # zero masked lanes
            x_hat = x * rstd
            accW += dy * x_hat
            accB += dy
    # Reduce the whole accumulator tile to this channel's scalar gradients.
    dw = tl.sum(accW)
    db = tl.sum(accB)
    tl.store(dW, dw, mask=c_mask)
    tl.store(dB, db, mask=c_mask)

543 

544 

class InstanceNorm(torch.autograd.Function):
    """Instance normalization backed by the Triton kernels in this module.

    Normalizes each (batch, channel) row of flattened spatial size N,
    optionally applying a per-channel affine transform and maintaining
    running statistics.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        use_input_stats=False,
        momentum=0.1,
        eps=1e-05,
        cudnn_enable=False,
    ):
        """Compute y = (x - mean) * rstd * weight + bias per (B, C) row.

        Args:
            x: input of shape [B, C, *spatial] (3, 4 or 5 dims).
            weight, bias: optional per-channel affine parameters of shape [C].
            running_mean, running_var: optional running statistics of shape
                [C]; updated in place when ``use_input_stats`` is True.
            use_input_stats: compute statistics from ``x`` (training mode)
                instead of using the running statistics (eval mode).
            momentum: running-statistics update factor.
            eps: numerical-stability term added to the variance.
            cudnn_enable: accepted for API compatibility; unused.

        Returns:
            Normalized tensor with the same shape as ``x``.
        """
        logger.debug("GEMS INSTANCENORM FORWARD")
        assert len(x.shape) in [
            3,
            4,
            5,
        ], f"x.shape should be [B, C, N] or [B, C, H, W] or [B, C, H, W, L], but got {x.shape}"
        B, C = x.shape[:2]
        N = math.prod(x.shape[2:])  # flattened spatial size
        M = x.numel() // N  # number of (batch, channel) rows, M = B * C

        x = x.contiguous()
        weight = weight.contiguous() if weight is not None else None
        bias = bias.contiguous() if bias is not None else None
        y = torch.empty_like(x)

        has_weight_bias = weight is not None
        if has_weight_bias:
            assert weight is not None and bias is not None

        has_running_stats = running_mean is not None
        if has_running_stats:
            assert (
                N > 1
            ), f"Expected more than 1 spatial element when training, got input size {x.shape}"
            assert (
                running_mean is not None and running_var is not None
            ), "running_mean and running_var should not both be None"
            assert (
                running_mean.shape == running_var.shape and running_mean.shape[0] == C
            ), f"running_mean and running_var should have shape as {[C,]}"
            assert (
                running_mean.dtype == running_var.dtype
            ), "running_mean and running_var should have the same dtype"
        if not use_input_stats:
            assert (
                has_running_stats
            ), "Expected running_mean and running_var to be defined when use_input_stats is False"

        # NOTE: when the input is half-precision(either float16 or bfloat16)
        # these statistical data saved for backward is in single precision
        acc_type = get_accumulator_dtype(x.dtype)
        mean = torch.empty(size=(B, C), dtype=acc_type, device=x.device)
        rstd = torch.empty(size=(B, C), dtype=acc_type, device=x.device)

        with torch_device_fn.device(x.device):
            # BUGFIX: the branches below were previously a flat
            #   if use_input_stats: ...
            #   if has_running_stats and use_input_stats: ...
            #   else: <running-stats kernel>
            # so calling with use_input_stats=True and *no* running stats
            # fell into the running-stats kernel with None buffers and
            # clobbered y. The eval-mode kernel must only run when
            # use_input_stats is False.
            if use_input_stats:
                # Training path: compute statistics from the input.
                grid = (12, 1, 1)  # 12 = cluster_num on this backend
                instancenorm_fwd_kernel_xpu[grid](
                    x,
                    y,
                    weight,
                    bias,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                    XBLOCK=triton.next_power_of_2(triton.cdiv(M, 12)),
                    RBLOCK=8192,
                    isCloseUnrollControl=True,
                    buffer_size_limit=512,
                )
                if has_running_stats:  # update running stats
                    grid = lambda meta: (
                        triton.cdiv(C, meta["BLOCK_CHANNEL_SIZE"]),
                        1,
                        1,
                    )
                    update_running_stats_kernel[grid](
                        mean,
                        rstd,
                        running_mean,
                        running_var,
                        momentum,
                        B,
                        C,
                        N,
                        eps,
                        isCloseCoreTiling=True,
                        isCloseVectorization=True,
                        isCloseUnrollControl=True,
                    )
            else:  # use running stats instead of input stats
                TILE_N = triton.next_power_of_2(N)
                grid = (M, 1, 1)
                instance_norm_use_running_stats_kernel[grid](
                    x,
                    y,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    TILE_N,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                    isCloseUnrollControl=True,
                )

        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.M = M
        ctx.N = N
        ctx.C = C
        ctx.has_weight_bias = has_weight_bias
        return y

    @staticmethod
    def backward(ctx, out_grad):
        """Gradients w.r.t. x and, when affine parameters were used, weight
        and bias. Returns one entry per forward input (Nones for the
        non-differentiable arguments).
        """
        logger.debug("GEMS INSTANCENORM BACKWARD")
        out_grad = out_grad.contiguous()
        (x, weight, mean, rstd) = ctx.saved_tensors
        M = ctx.M
        N = ctx.N
        C = ctx.C
        B = M // C

        with torch_device_fn.device(x.device):
            in_grad = torch.empty_like(x)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)

            instance_norm_backward_kernel[grid](
                out_grad,
                x,
                weight,
                mean,
                rstd,
                in_grad,
                M,
                N,
                C,
                HAS_WEIGHT_BIAS=ctx.has_weight_bias,
                isCloseCoreTiling=True,
            )

            if ctx.has_weight_bias:
                grid = lambda meta: (C, 1, 1)
                weight_grad = torch.empty_like(weight)
                bias_grad = torch.empty_like(weight)  # bias shares weight's [C] shape
                weight_bias_backward_kernel[grid](
                    out_grad,
                    x,
                    mean,
                    rstd,
                    weight_grad,
                    bias_grad,
                    M,
                    N,
                    B,
                    C,
                )
            else:
                weight_grad = None
                bias_grad = None
        return in_grad, weight_grad, bias_grad, None, None, None, None, None, None

719 

720 

def instance_norm(
    input: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
    cudnn_enable: bool = False,
) -> Tensor:
    r"""Apply instance normalization independently per channel of each
    sample in a batch.

    Inputs:
        input: tensor of shape :math:`(N, C, *)`
        weight: optional per-channel scale of shape :math:`(C)`
        bias: optional per-channel shift of shape :math:`(C)`
        running_mean: optional running mean of shape :math:`(C)`
        running_var: optional running variance of shape :math:`(C)`
        use_input_stats: compute statistics from the input instead of the
            running buffers
        momentum: update factor for the running statistics
        eps: epsilon added to the variance for numerical stability
        cudnn_enable: accepted for API compatibility; not forwarded to the
            autograd function (which ignores it anyway)
    Returns:
        tensor of shape :math:`(N, C, *)`
    """
    return InstanceNorm.apply(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats,
        momentum,
        eps,
    )