Coverage for src/flag_gems/runtime/backend/_cambricon/ops/softmax.py: 0%

546 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import copy 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12 

13from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM 

14 

# Module logger: child of the "flag_gems" root logger; leading dots are
# stripped so relative-module names nest cleanly under it.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Threshold on the reduced-dimension size: below this the prune functions
# collapse the tile to cover the whole reduced dimension at once.
MAX_N = 16384

17 

18 

def align(max_block):
    """Round *max_block* down to a power of two.

    If *max_block* is already a power of two it is returned unchanged;
    otherwise half of the next power of two (i.e. the largest power of two
    not exceeding it) is returned.
    """
    rounded_up = triton.next_power_of_2(max_block)
    if rounded_up == max_block:
        return max_block
    return rounded_up // 2

22 

23 

def config_prune1(configs, named_args, **kwargs):
    """Early-prune autotuner configs for the non-inner softmax forward kernel.

    For N < MAX_N each candidate config is replaced by three derived variants
    (whole-N tile with K split over cores, largest unpipelined TILE_K, and a
    pipelined TILE_K); duplicates are collapsed through a
    (TILE_K, TILE_N, num_warps, num_stages) key. Two heuristic TILE_K=1
    fallbacks (pipelined and not) are always appended.

    Args:
        configs: candidate triton.Config objects from the tuned-config table.
        named_args: kernel arguments; M/N/K shape and the input tensor.

    Returns:
        The pruned (deduplicated) list of configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    inp = named_args["input_ptr"]  # renamed from `input` to avoid shadowing the builtin
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        TILE_K, TILE_N, num_warps, num_stages = (
            kw["TILE_K"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Variant 1: one tile covers all of N; split K evenly across cores.
            config = copy.deepcopy(config)
            TILE_N = config.kwargs["TILE_N"] = N
            k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
            TILE_K = config.kwargs["TILE_K"] = k_per_core
            num_stages = config.num_stages = 1
            configs_map.setdefault((TILE_K, TILE_N, num_warps, num_stages), config)

            # Variant 2: largest TILE_K that fits NRAM without pipelining.
            config = copy.deepcopy(config)
            max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (2 * TILE_N + 1)
            TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
            num_stages = config.num_stages = 1
            configs_map.setdefault((TILE_K, TILE_N, num_warps, num_stages), config)

            # Variant 3: pipelined (num_stages=3); fp32 needs an extra buffer,
            # hence the larger divisor.
            config = copy.deepcopy(config)
            max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (3 * TILE_N + 1)
            if inp.dtype == torch.float32:
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (4 * TILE_N + 1)
            TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
            num_stages = config.num_stages = 3
            configs_map.setdefault((TILE_K, TILE_N, num_warps, num_stages), config)
        else:
            # Large N: keep the original config, deduplicated by key.
            configs_map.setdefault((TILE_K, TILE_N, num_warps, num_stages), config)
    pruned_configs = list(configs_map.values())
    # Heuristic fallbacks: TILE_K=1 with TILE_N covering the full row,
    # once pipelined and once not. Assumes `configs` is non-empty.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["TILE_K"] = 1
    extra_config.kwargs["TILE_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs

78 

79 

def softmax_tile_mode_for_non_inner(M, N, K, TILE_N, TILE_K):
    """Classify the tiling for the non-inner kernels.

    Returns 0 when a single tile covers both N and the per-core share of K,
    1 when only N fits in one tile, and 2 when N itself must be tiled.
    """
    cores_per_row = max(TOTAL_CORE_NUM // M, 1)
    covers_k = TILE_K * cores_per_row >= K
    covers_n = TILE_N >= N
    if not covers_n:
        return 2
    return 0 if covers_k else 1

89 

90 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune1},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax along the middle dimension N of a contiguous (M, N, K) view,
    # i.e. the reduced axis is NOT the innermost one (its stride is K).
    # Grid: axis 0 -> M, axis 1 -> programs splitting K.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)  # K-range owned by this program
    k_start = pid_k * split_k

    if TILE_MODE == 0:
        # One (TILE_N, TILE_K) tile covers this program's entire workload;
        # only K needs masking (mode 0 implies TILE_N >= N).
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Stable softmax: subtract the per-column (per-k) max before exp.
        row_minus_max = inp - tl.max(inp, axis=0)[None, :]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)[None, :]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, softmax_output, mask=mask)
    elif TILE_MODE == 1:
        # TILE_N covers N, but this program's K slice needs TILE_K-sized chunks.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = k_offset[None, :] < K
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            row_minus_max = inp - tl.max(inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = numerator * recip
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, softmax_output, mask=mask)
    else:
        # N larger than TILE_N: two-pass online softmax per K chunk.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            # Running max (m) and running sum of exponentials (z) per lane.
            m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
            z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                m_new = tl.maximum(m, inp)
                # Guard against exp(-inf - -inf) = nan while a lane is all -inf.
                all_neg_inf = m_new == float("-inf")
                z = tl.where(
                    all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new)
                )
                m = m_new
            # Collapse the per-lane running stats into per-k max and sum.
            m_reduced = tl.max(m, 0)  # (TILE_K,)
            z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
            recip_z = 1.0 / z
            m = m_reduced
            # specialization does not improve performance in this example, as tested
            # Second pass: re-read the input and emit exp(x - m) / z.
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                o = tl.exp(inp - m[None, :]) * recip_z[None, :]
                tl.store(output_ptr + offset, o, mask=mask)

181 

182 

def config_prune2(configs, named_args, **kwargs):
    """Early-prune autotuner configs for the inner softmax forward kernel.

    For N < MAX_N each candidate config is replaced by three derived variants
    (whole-row tile with rows split over cores, largest unpipelined BLOCK_M,
    pipelined BLOCK_M); duplicates are collapsed through a
    (BLOCK_M, BLOCK_N, num_warps, num_stages) key. Two heuristic BLOCK_M=1
    fallbacks are always appended.

    Args:
        configs: candidate triton.Config objects from the tuned-config table.
        named_args: kernel arguments; M/N shape and the input tensor.

    Returns:
        The pruned (deduplicated) list of configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    inp = named_args["input_ptr"]  # renamed from `input` to avoid shadowing the builtin
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Variant 1: BLOCK_N covers the whole row; rows split over cores.
            config = copy.deepcopy(config)
            BLOCK_N = config.kwargs["BLOCK_N"] = N
            m_per_core = math.ceil(M / TOTAL_CORE_NUM)
            BLOCK_M = config.kwargs["BLOCK_M"] = m_per_core
            num_stages = config.num_stages = 1
            configs_map.setdefault((BLOCK_M, BLOCK_N, num_warps, num_stages), config)

            # Variant 2: largest BLOCK_M that fits NRAM without pipelining.
            config = copy.deepcopy(config)
            max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (2 * BLOCK_N + 1)
            BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
            num_stages = config.num_stages = 1
            configs_map.setdefault((BLOCK_M, BLOCK_N, num_warps, num_stages), config)

            # Variant 3: pipelined (num_stages=3); fp32 needs wider buffers.
            config = copy.deepcopy(config)
            max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (4 * BLOCK_N + 1)
            if inp.dtype == torch.float32:
                max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (6 * BLOCK_N + 1)
            BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
            num_stages = config.num_stages = 3
            configs_map.setdefault((BLOCK_M, BLOCK_N, num_warps, num_stages), config)
        else:
            # Large N: keep the original config, deduplicated by key.
            # (The original code ran this insertion unconditionally after the
            # branch above, where it was a no-op duplicate `setdefault`.)
            configs_map.setdefault((BLOCK_M, BLOCK_N, num_warps, num_stages), config)
    pruned_configs = list(configs_map.values())
    # Add heuristic configs: BLOCK_M=1 over the full row, pipelined and not.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["BLOCK_M"] = 1
    extra_config.kwargs["BLOCK_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs

238 

239 

def softmax_tile_mode_for_inner(args):
    """Classify the tiling for the inner softmax kernels.

    Returns 0 when one launch covers all rows and BLOCK_N covers the row,
    1 when only the row fits in one tile, and 2 when the row must be tiled.
    """
    covers_rows = args["BLOCK_M"] * TOTAL_CORE_NUM >= args["M"]
    covers_row_width = args["BLOCK_N"] >= args["N"]
    if not covers_row_width:
        return 2
    return 0 if covers_rows else 1

249 

250 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune2},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax over the innermost (contiguous) dimension of an (M, N) view.
    # Grid axis 0 distributes row blocks across programs.
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)  # rows owned by each program
    m_start = pid_m * split_m

    if TILE_MODE == 0:
        # One (BLOCK_M, BLOCK_N) tile covers this program's rows and the
        # full row width; only the row index needs masking.
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Stable softmax: shift by the row max before exponentiating.
        row_minus_max = inp - tl.max(inp, axis=1)[:, None]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=1)[:, None]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, softmax_output, mask=mask)
    elif TILE_MODE == 1:
        # BLOCK_N covers the row; walk this program's rows in BLOCK_M chunks.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            # Transpose so the reduction runs along axis 0 — presumably the
            # faster reduction axis on this backend. NOTE(review): confirm.
            trans_inp = tl.trans(inp)
            row_minus_max = trans_inp - tl.max(trans_inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = tl.trans(numerator * recip)
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, softmax_output, mask=mask)
    else:
        # Rows longer than BLOCK_N: two-pass online softmax per BLOCK_M chunk.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            # Running max and running sum of exponentials per lane.
            block_max = tl.full(
                [BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32
            )
            block_sum = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                cur_max = tl.maximum(block_max, inp)
                # Guard against exp(-inf - -inf) = nan while a lane is all -inf.
                all_neg_inf = cur_max == float("-inf")
                block_sum = tl.where(
                    all_neg_inf,
                    block_sum,
                    block_sum * tl.exp(block_max - cur_max) + tl.exp(inp - cur_max),
                )
                block_max = cur_max

            # Collapse the per-lane running stats into per-row max and sum
            # (transposed so the reduction runs along axis 0).
            trans_block_max = tl.trans(block_max)
            trans_block_sum = tl.trans(block_sum)
            max_reduced = tl.max(trans_block_max, 0)
            total_sum = tl.sum(
                trans_block_sum * tl.exp(trans_block_max - max_reduced[None, :]), 0
            )
            recip_total_sum = 1.0 / total_sum
            total_max = max_reduced

            # Second pass: re-read the input and write the normalized result.
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                o = tl.exp(inp - total_max[:, None]) * recip_total_sum[:, None]
                tl.store(output_ptr + offset, o, mask=mask)

348 

349 

@triton.jit
def softmax_kernel_inner_k_partial_stats(
    x_ptr,
    max_buf_ptr,
    sum_buf_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Pass 1 of the 3-kernel split softmax: for each (row-block, N-tile) pair
    # compute the tile-local max and the sum of exp(x - tile_max), writing
    # them to (M, T) staging buffers for the merge kernel.
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # Flat work list of (row-block, tile) tasks, ceil-divided over programs.
    # NOTE(review): M // BLOCK_M drops a trailing partial row-block when
    # M % BLOCK_M != 0; the caller launches with BLOCK_M=1, so all rows are covered.
    total_blocks = (M // BLOCK_M) * T
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T  # which BLOCK_M row block
        tile_id = task % T  # which BLOCK_N tile along the row

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        tile_max = tl.max(tile, axis=1)
        all_neg_inf = tile_max == -float("inf")

        # An all--inf tile contributes 0 (avoids exp(-inf - -inf) = nan).
        tile_sum = tl.where(
            all_neg_inf,
            0.0,
            tl.sum(tl.exp(tile - tile_max[:, None]), axis=1),
        )

        # Stats are laid out (M, T): one column per N-tile.
        tl.store(max_buf_ptr + offs_m * T + tile_id, tile_max, mask=(offs_m < M))
        tl.store(sum_buf_ptr + offs_m * T + tile_id, tile_sum, mask=(offs_m < M))

394 

@triton.jit
def softmax_kernel_inner_k_merge_stats(
    max_buf_ptr,
    sum_buf_ptr,
    gmax_ptr,
    gsum_ptr,
    M: tl.constexpr,
    T: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Pass 2 of the split softmax: reduce the per-tile (max, sum) pairs along
    # the T axis into one global max and one rescaled global sum per row.
    pid_m = tl.program_id(axis=0)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # [BM]
    mask_m = offs_m < M
    # NOTE(review): tl.arange(0, T) generally requires T to be a power of two
    # in Triton — confirm for the tile counts the caller can produce.
    tile_max = tl.load(
        max_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=-float("inf"),
    )
    tile_sum = tl.load(
        sum_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=0.0,
    ).to(tl.float32)

    gmax = tl.max(tile_max, axis=1)
    # Rescale each tile's partial sum from its local max to the global max.
    scale = tl.exp(tile_max - gmax[:, None])
    # Rows that are entirely -inf keep a zero sum instead of producing nan.
    scale = tl.where(gmax[:, None] == -float("inf"), 0.0, scale)
    gsum = tl.sum(tile_sum * scale, axis=1)

    tl.store(gmax_ptr + offs_m, gmax, mask=mask_m)
    tl.store(gsum_ptr + offs_m, gsum, mask=mask_m)

426 

427 

@triton.jit
def softmax_kernel_inner_k_write_softmax(
    x_ptr,
    y_ptr,
    gmax_ptr,
    gsum_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Pass 3 of the split softmax: using the per-row global stats from the
    # merge kernel, write y = exp(x - gmax) / gsum tile by tile.
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # Same flat (row-block, tile) work distribution as the stats kernel.
    # NOTE(review): M // BLOCK_M drops a trailing partial row-block when
    # M % BLOCK_M != 0; the caller launches with BLOCK_M=1, so all rows are covered.
    total_blocks = (M // BLOCK_M) * T
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        # load global stats
        gmax = tl.load(gmax_ptr + offs_m, mask=(offs_m < M), other=-float("inf")).to(
            tl.float32
        )
        gsum = tl.load(gsum_ptr + offs_m, mask=(offs_m < M), other=0.0).to(tl.float32)

        # load tile
        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        # Rows with a zero sum (entirely -inf input) emit 0 instead of nan.
        valid = gsum[:, None] > 0

        out = tl.where(
            valid,
            tl.exp(tile - gmax[:, None]) / gsum[:, None],
            0.0,
        )

        tl.store(y_ptr + offs_m[:, None] * N + offs_n[None, :], out, mask=mask)

477 

478 

479# ------------------------ backward ------------------------------- 

480 

481 

def nram_usage_for_backward_non_inner(bn, bk, tile_mode, num_stages, dtype):
    """Estimate on-chip NRAM bytes used by the backward non-inner kernel.

    Args:
        bn: TILE_N (reduced-dimension tile size).
        bk: TILE_K (post-dimension tile size).
        tile_mode: tiling mode as returned by softmax_tile_mode_for_non_inner.
        num_stages: pipeline depth of the candidate config.
        dtype: tensor dtype; float32 needs wider pipelined buffers.

    Returns:
        Estimated footprint in bytes: (coef * bn + 1) * bk * 4, where coef
        counts the live on-chip buffers for the given mode/pipelining.
    """
    if tile_mode == 0:
        coef = 3
    elif tile_mode == 1:
        if num_stages == 1:
            coef = 3
        else:
            coef = 7 if dtype == torch.float32 else 6
    else:
        if num_stages == 1:
            coef = 5
        else:
            coef = 13 if dtype == torch.float32 else 10
    return (coef * bn + 1) * bk * 4

503 

504 

def config_prune3(configs, named_args, **kwargs):
    """Early-prune autotuner configs for the backward non-inner kernel.

    If the whole per-core workload fits on-chip without looping, a single
    loop-free config is returned. Otherwise candidates are filtered by
    256-byte alignment of TILE_K and by estimated NRAM usage.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    output = named_args["output_ptr"]
    dtype = output.dtype
    k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
    # Fast path: no loop needed at all — pin a single full-coverage config.
    if nram_usage_for_backward_non_inner(N, k_per_core, 0, 1, dtype) < MAX_NRAM_SIZE:
        single = copy.deepcopy(configs[0])
        single.kwargs["TILE_K"] = k_per_core
        single.kwargs["TILE_N"] = N
        single.num_stages = 1
        return [single]
    # Elements per 256 bytes for the element size of `dtype`.
    align_num = 256 // 4 if dtype == torch.float32 else 256 // 2
    kept = []
    for cfg in configs:
        tile_k = cfg.kwargs["TILE_K"]
        tile_n = cfg.kwargs["TILE_N"]
        # Keep the lowest dimension 256B-aligned for loads/stores.
        if tile_k % align_num != 0:
            continue
        # NRAM usage should stay within capacity but not waste over half of it.
        mode = softmax_tile_mode_for_non_inner(M, N, K, tile_n, tile_k)
        nram = nram_usage_for_backward_non_inner(
            tile_n, tile_k, mode, cfg.num_stages, dtype
        )
        if not (MAX_NRAM_SIZE // 2 <= nram <= MAX_NRAM_SIZE):
            continue
        kept.append(cfg)
    return kept

540 

541 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_non_inner_bw"),
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune3},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax VJP along the middle dimension N of an (M, N, K) view:
    #   in_grad = out * (out_grad - sum(out * out_grad, axis=N))
    # Grid: axis 0 -> M, axis 1 -> programs splitting K.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)  # K-range owned by this program
    k_start = pid_k * split_k

    if TILE_MODE == 0:
        # One tile covers the full N and this program's K slice.
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, axis=0)  # per-k dot product
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        # TILE_N covers N; walk this program's K slice in TILE_K chunks.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = k_offset[None, :] < K and n_offset[:, None] < N
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_tile * out_grad_tile, axis=0)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        # N larger than TILE_N: first pass accumulates the dot product,
        # second pass re-reads the tiles to produce the gradient.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_tile * out_grad_tile
            scale = tl.sum(scale, axis=0)  # reduce the partial products over N
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)

613 

614 

def config_prune4(configs, named_args, **kwargs):
    """Early-prune autotuner configs for the backward inner kernel.

    For N < MAX_N each candidate config is replaced by three derived variants
    (whole-row tile with rows split over cores, largest unpipelined BLOCK_M,
    pipelined BLOCK_M); duplicates are collapsed through a
    (BLOCK_M, BLOCK_N, num_warps, num_stages) key. Two heuristic BLOCK_M=1
    fallbacks are always appended.

    Args:
        configs: candidate triton.Config objects from the tuned-config table.
        named_args: kernel arguments; M/N shape and the forward output tensor.

    Returns:
        The pruned (deduplicated) list of configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    output = named_args["output_ptr"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Variant 1: BLOCK_N covers the whole row; rows split over cores.
            config = copy.deepcopy(config)
            BLOCK_N = config.kwargs["BLOCK_N"] = N
            m_per_core = math.ceil(M / TOTAL_CORE_NUM)
            BLOCK_M = config.kwargs["BLOCK_M"] = m_per_core
            num_stages = config.num_stages = 1
            configs_map.setdefault((BLOCK_M, BLOCK_N, num_warps, num_stages), config)

            # Variant 2: largest BLOCK_M that fits NRAM without pipelining.
            config = copy.deepcopy(config)
            max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (3 * BLOCK_N + 1)
            BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
            num_stages = config.num_stages = 1
            configs_map.setdefault((BLOCK_M, BLOCK_N, num_warps, num_stages), config)

            # Variant 3: pipelined (num_stages=3); fp32 needs wider buffers.
            config = copy.deepcopy(config)
            max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (6 * BLOCK_N + 1)
            if output.dtype == torch.float32:
                max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (7 * BLOCK_N + 1)
            BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
            num_stages = config.num_stages = 3
            configs_map.setdefault((BLOCK_M, BLOCK_N, num_warps, num_stages), config)
        else:
            # Large N: keep the original config, deduplicated by key.
            # (The original code ran this insertion unconditionally after the
            # branch above, where it was a no-op duplicate `setdefault`.)
            configs_map.setdefault((BLOCK_M, BLOCK_N, num_warps, num_stages), config)
    pruned_configs = list(configs_map.values())
    # Add heuristic configs: BLOCK_M=1 over the full row, pipelined and not.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["BLOCK_M"] = 1
    extra_config.kwargs["BLOCK_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs

670 

671 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_inner_bw"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune4},
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax VJP over the innermost dimension of an (M, N) view:
    #   in_grad = out * (out_grad - sum(out * out_grad, axis=-1))
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)  # rows owned by each program
    m_start = pid_m * split_m

    if TILE_MODE == 0:
        # One tile covers this program's rows and the full row width.
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, 1)  # per-row dot product
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        # BLOCK_N covers the row; walk this program's rows in BLOCK_M chunks.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_tile * out_grad_tile, 1)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        # Rows longer than BLOCK_N: first pass accumulates the dot product,
        # second pass re-reads the tiles to produce the gradient.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            scale = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                # evict_last: the same output tile is re-read in the second pass.
                out_tile = tl.load(
                    output_ptr + offset, mask=mask, eviction_policy="evict_last"
                ).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_tile * out_grad_tile
            scale = tl.sum(scale, 1)  # reduce partial products over the row
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                # evict_first: this pass is the last use of the output tile.
                out_tile = tl.load(
                    output_ptr + offset, mask=mask, eviction_policy="evict_first"
                ).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)

745 

746 

def softmax(self, dim, half_to_float=False):
    """Softmax of `self` along `dim` (Cambricon MLU backend).

    The tensor is viewed as (M, N, K): M = product of dims before `dim`,
    N = size of `dim`, K = product of dims after `dim`. K > 1 dispatches
    to the non-inner kernel; K == 1 to the inner kernels.

    Args:
        self: input tensor (made contiguous internally).
        dim: dimension to reduce; may be negative.
        half_to_float: when True, the output dtype is float32 regardless of input.

    Returns:
        A new tensor with softmax applied along `dim`.
    """
    logger.debug("GEMS_CAMBRICON SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]  # pre_dim
    self = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim

    with torch_device_fn.device(self.device):
        if K > 1:
            # Reduced axis is not the innermost: non-inner kernel.
            logger.debug("GEMS_CAMBRICON SOFTMAX USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            softmax_kernel_non_inner[grid](
                out,
                self,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON SOFTMAX USE INNER")
            # Many rows, or rows short enough for the fused kernel.
            if M > TOTAL_CORE_NUM or N < 1024 * 8 * 8:
                softmax_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                    out,
                    self,
                    M,
                    N,
                )
            else:
                # Few but very long rows: 3-pass split softmax
                # (per-tile stats -> merge stats -> normalized write).
                block_m = 1
                block_n = 8192 * 4
                if dtype is torch.float32:
                    block_n = 8192 * 2  # fp32 tiles take twice the on-chip space
                # workspace
                T = (N + block_n - 1) // block_n  # number of N-tiles per row
                max_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                sum_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                gmax = torch.empty((M,), device=self.device, dtype=torch.float32)
                gsum = torch.empty((M,), device=self.device, dtype=torch.float32)
                # kernel 1: per-tile stats
                softmax_kernel_inner_k_partial_stats[(TOTAL_CORE_NUM,)](
                    self,
                    max_buf,
                    sum_buf,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    # NOTE(review): presumably an MLU-specific launch hint — confirm.
                    bottleneck="simd",
                    num_stages=3,
                )
                # kernel 2: merge stats along N-tiles
                grid_merge = (triton.cdiv(M, block_m),)
                softmax_kernel_inner_k_merge_stats[grid_merge](
                    max_buf, sum_buf, gmax, gsum, M, T, BLOCK_M=block_m
                )
                # kernel 3 re-tiles with half the tile width; T is recomputed to match.
                block_n = block_n // 2
                T = (N + block_n - 1) // block_n
                # kernel 3: write normalized outputs
                softmax_kernel_inner_k_write_softmax[(TOTAL_CORE_NUM,)](
                    self,
                    out,
                    gmax,
                    gsum,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
    return out

830 

831 

def softmax_backward(grad_output, output, dim, input_dtype):
    """Softmax VJP: in_grad = out * (out_grad - sum(out * out_grad, dim)).

    Args:
        grad_output: upstream gradient (made contiguous internally).
        output: forward softmax result.
        dim: reduced dimension; may be negative.
        input_dtype: dtype of the original forward input (unused here).

    Returns:
        Gradient w.r.t. the softmax input, same shape/dtype as `output`.
    """
    logger.debug("GEMS_CAMBRICON SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    # Same (M, N, K) collapse as the forward path.
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # NOTE(review): only grad_output is made contiguous; `output` is presumably
    # already contiguous as the product of the forward kernel — confirm.
    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output)
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            # Reduced axis is not the innermost: non-inner kernel.
            logger.debug("GEMS_CAMBRICON SOFTMAX VJP USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON SOFTMAX VJP USE INNER")
            softmax_backward_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad