Coverage for src/flag_gems/runtime/backend/_cambricon/ops/log_softmax.py: 0%

556 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import copy 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12 

13from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM 

14 

# Module-level logger nested under the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Threshold on the softmax-axis length N: below this the config_prune*
# helpers collapse the autotune configs to a single tile covering all of N.
MAX_N = 16384

17 

18 

def align(max_block):
    """Round ``max_block`` down to a power of two.

    A value that is already a power of two is returned unchanged; any
    other value is replaced by half of the next power of two above it,
    i.e. the largest power of two below ``max_block``.
    """
    rounded_up = triton.next_power_of_2(max_block)
    if rounded_up == max_block:
        return max_block
    return rounded_up // 2

22 

23 

def config_prune1(configs, named_args, **kwargs):
    """Early-prune autotune configs for the forward non-inner kernel.

    For small N (< MAX_N) the candidate configs are rewritten so a single
    tile covers the whole softmax axis (TILE_N = N), and TILE_K / num_stages
    are chosen from an NRAM-usage estimate; duplicates (same key) are
    dropped.  Two heuristic fallback configs are always appended.

    named_args must provide "M", "N", "K" and the input tensor
    "input_ptr" (its dtype selects the NRAM coefficient).
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    input = named_args["input_ptr"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        TILE_K, TILE_N, num_warps, num_stages = (
            kw["TILE_K"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Deep-copy before mutating so the shared config list is untouched.
            config = copy.deepcopy(config)
            TILE_N = config.kwargs["TILE_N"] = N
            # K rows handled by each core when the K axis is split over cores.
            k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
            # Estimated bytes: two fp32 (TILE_N, k_per_core) buffers plus one
            # fp32 vector of length k_per_core.
            nram_usage = (2 * TILE_N + 1) * k_per_core * 4
            if nram_usage < MAX_NRAM_SIZE:
                # Everything fits without pipelining: one tile per core.
                TILE_K = config.kwargs["TILE_K"] = k_per_core
                num_stages = config.num_stages = 1
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
            else:
                # Largest TILE_K that fits without software pipelining.
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (2 * TILE_N + 1)
                TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
                num_stages = config.num_stages = 1
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)

                # Also offer a pipelined variant (num_stages=3); pipelining
                # needs extra buffers, hence the larger divisor (one more
                # buffer for fp32 inputs).
                config = copy.deepcopy(config)
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (3 * TILE_N + 1)
                if input.dtype == torch.float32:
                    max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (4 * TILE_N + 1)
                TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
                num_stages = config.num_stages = 3
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
        else:
            # Large N: keep the candidate as-is, dedupe by key.
            key = (TILE_K, TILE_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
    pruned_configs = []
    for k, v in configs_map.items():
        pruned_configs.append(v)
    # Heuristic fallbacks: one row of K at a time with a full-N tile,
    # pipelined and non-pipelined.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["TILE_K"] = 1
    extra_config.kwargs["TILE_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs

79 

80 

def log_softmax_tile_mode_for_non_inner(M, N, K, TILE_N, TILE_K):
    """Pick a tiling mode for the non-inner (K > 1) kernels.

    Returns 0 when one tile covers both the softmax axis N and each
    core's share of K, 1 when only N fits in a single tile, and 2 when
    N itself must be processed in multiple tiles.
    """
    cores_per_row = max(TOTAL_CORE_NUM // M, 1)
    covers_k = TILE_K * cores_per_row >= K
    covers_n = TILE_N >= N
    if not covers_n:
        return 2
    return 0 if covers_k else 1

90 

91 

@libentry()
@libtuner(
    configs=[
        triton.Config({"TILE_K": k, "TILE_N": 2**n}, num_stages=s, num_warps=1)
        for k in [1, 2, 4, 8]
        for n in range(10, 15, 1)
        for s in [1, 3]
    ],
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune1},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: log_softmax_tile_mode_for_non_inner(
            args["M"], args["N"], args["K"], args["TILE_N"], args["TILE_K"]
        ),
    },
)
@triton.jit
def log_softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    """log_softmax over axis N of a contiguous (M, N, K) view, K > 1.

    Grid: (M, num_k_programs).  TILE_MODE (see
    log_softmax_tile_mode_for_non_inner): 0 = one tile covers N and the
    core's K share; 1 = loop over K, one tile covers N; 2 = loop over
    both K and N with an online (streaming) max/sum reduction.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # Each program along axis 1 handles a contiguous slice of K.
    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)
    k_start = pid_k * split_k

    # log2(e); tl.log2(x) / log2e == ln(x).
    log2e = 1.442695

    if TILE_MODE == 0:
        # Single tile: TILE_N >= N here, so no mask is needed along N.
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Standard stable log-softmax: subtract the column max, then ln(sum).
        m = inp - tl.max(inp, axis=0)[None, :]
        e = tl.exp(m)
        s = tl.sum(e, axis=0)[None, :]
        output = m - tl.log2(s) / log2e
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, output, mask=mask)
    elif TILE_MODE == 1:
        # N fits in one tile; iterate over this core's K slice.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = k_offset[None, :] < K
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            m = inp - tl.max(inp, axis=0)[None, :]
            e = tl.exp(m)
            s = tl.sum(e, axis=0)[None, :]
            output = m - tl.log2(s) / log2e
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, output, mask=mask)
    else:
        # N is tiled too: first pass keeps a running max (m) and rescaled
        # running sum (z) per K column, second pass writes the result.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
            z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                m_new = tl.maximum(m, inp)
                # Skip the rescale while every element seen so far is -inf,
                # avoiding exp(-inf - -inf) = NaN.
                all_neg_inf = m_new == float("-inf")
                z = tl.where(
                    all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new)
                )
                m = m_new

            # Collapse the per-tile partials to one max/sum per K column.
            m_reduced = tl.max(m, 0)  # (TILE_K,)
            z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
            recip_z = 1.0 / z
            m = m_reduced

            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                # log_softmax = ln(softmax); softmax = exp(x - max) / sum.
                o = tl.exp(inp - m[None, :]) * recip_z[None, :]
                output = tl.log2(o) / log2e
                tl.store(output_ptr + offset, output, mask=mask)

197 

198 

def config_prune2(configs, named_args, **kwargs):
    """Early-prune autotune configs for the forward inner kernel (K == 1).

    Mirrors config_prune1 with the roles of the axes swapped: BLOCK_N tiles
    the softmax axis N and BLOCK_M tiles the batch axis M.  named_args must
    provide "M", "N" and the input tensor "input_ptr".
    """
    M = named_args["M"]
    N = named_args["N"]
    input = named_args["input_ptr"]
    configs_map = {}
    # When N is less than MAX_N, rewrite configs so no reduction loop over
    # N is needed (BLOCK_N = N).
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Deep-copy before mutating the shared candidate config.
            config = copy.deepcopy(config)
            BLOCK_N = config.kwargs["BLOCK_N"] = N
            # Rows of M handled by each core.
            m_per_core = math.ceil(M / TOTAL_CORE_NUM)
            # Estimated bytes: two fp32 (m_per_core, BLOCK_N) buffers plus
            # one fp32 vector of length m_per_core.
            nram_usage = (2 * BLOCK_N + 1) * m_per_core * 4
            if nram_usage < MAX_NRAM_SIZE:
                # Fits without pipelining: one tile per core.
                BLOCK_M = config.kwargs["BLOCK_M"] = m_per_core
                num_stages = config.num_stages = 1
                key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
            else:
                # Largest BLOCK_M that fits without software pipelining.
                max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (2 * BLOCK_N + 1)
                BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
                num_stages = config.num_stages = 1
                key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
                configs_map.setdefault(key, config)

                # Pipelined variant (num_stages=3); the divisor grows because
                # pipelining keeps extra buffers resident (more for fp32).
                config = copy.deepcopy(config)
                max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (4 * BLOCK_N + 1)
                if input.dtype == torch.float32:
                    max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (6 * BLOCK_N + 1)
                BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
                num_stages = config.num_stages = 3
                key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
        key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, config)
    pruned_configs = []
    for k, v in configs_map.items():
        pruned_configs.append(v)
    # Add a heuristic config.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["BLOCK_M"] = 1
    extra_config.kwargs["BLOCK_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs

255 

256 

def log_softmax_tile_mode_for_inner(M, N, BLOCK_M, BLOCK_N):
    """Pick a tiling mode for the inner (K == 1) kernels.

    Returns 0 when one tile per core covers both M and the softmax axis
    N, 1 when only N fits in a single tile, and 2 when N itself must be
    processed in multiple tiles.
    """
    covers_m = BLOCK_M * TOTAL_CORE_NUM >= M
    covers_n = BLOCK_N >= N
    if not covers_n:
        return 2
    return 0 if covers_m else 1

266 

267 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("log_softmax"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune2},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: log_softmax_tile_mode_for_inner(
            args["M"], args["N"], args["BLOCK_M"], args["BLOCK_N"]
        ),
    },
)
@triton.jit
def log_softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    """log_softmax over the last axis of a contiguous (M, N) view.

    TILE_MODE (see log_softmax_tile_mode_for_inner): 0 = single tile;
    1 = loop over M, one tile covers N; 2 = loop over both M and N with
    an online max/sum reduction.
    """
    pid_m = tl.program_id(0)
    # Each program handles a contiguous slice of M rows.
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)
    m_start = pid_m * split_m

    # log2(e); tl.log2(x) / log2e == ln(x).
    log2e = 1.442695

    if TILE_MODE == 0:
        # Single tile: BLOCK_N >= N here, so no mask is needed along N.
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Stable softmax per row, then ln() of it.
        row_minus_max = inp - tl.max(inp, axis=1)[:, None]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=1)[:, None]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output = tl.log2(softmax_output) / log2e
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, output, mask=mask)
    elif TILE_MODE == 1:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            # Reduce along axis 0 after a transpose (axis-0 reductions map
            # better to the hardware here), then transpose back.
            trans_inp = tl.trans(inp)
            row_minus_max = trans_inp - tl.max(trans_inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = tl.trans(numerator * recip)
            output = tl.log2(softmax_output) / log2e
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, output, mask=mask)
    else:
        # N is tiled: online max (block_max) and rescaled sum (block_sum)
        # per row, then a second pass to write the normalized result.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            block_max = tl.full(
                [BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32
            )
            block_sum = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                cur_max = tl.maximum(block_max, inp)
                # Skip the rescale while everything seen so far is -inf,
                # avoiding exp(-inf - -inf) = NaN.
                all_neg_inf = cur_max == float("-inf")
                block_sum = tl.where(
                    all_neg_inf,
                    block_sum,
                    block_sum * tl.exp(block_max - cur_max) + tl.exp(inp - cur_max),
                )
                block_max = cur_max

            # Collapse the per-column partials into one max/sum per row.
            trans_block_max = tl.trans(block_max)
            trans_block_sum = tl.trans(block_sum)
            max_reduced = tl.max(trans_block_max, 0)
            total_sum = tl.sum(
                trans_block_sum * tl.exp(trans_block_max - max_reduced[None, :]), 0
            )
            recip_total_sum = 1.0 / total_sum
            total_max = max_reduced

            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                # log_softmax = ln(exp(x - max) / sum).
                o = tl.exp(inp - total_max[:, None]) * recip_total_sum[:, None]
                output = tl.log2(o) / log2e
                tl.store(output_ptr + offset, output, mask=mask)

376 

377 

@triton.jit
def log_softmax_kernel_inner_k_partial_stats(
    x_ptr,
    max_buf_ptr,
    sum_buf_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Pass 1 of the split-N path: per-tile max and rescaled exp-sum.

    The (M, N) input is split into T tiles of BLOCK_N along N; for each
    (row-block, tile) pair this writes tl.max(tile) and
    sum(exp(tile - max)) into the (M, T) buffers max_buf / sum_buf.
    NOTE(review): total_blocks uses M // BLOCK_M (not cdiv), so M is
    assumed divisible by BLOCK_M — the caller launches with BLOCK_M=1.
    """
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    total_blocks = (M // BLOCK_M) * T
    # Evenly distribute (row-block, tile) tasks across the cores.
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        tile_max = tl.max(tile, axis=1)
        # An all -inf row would give exp(-inf - -inf) = NaN; record 0 instead.
        all_neg_inf = tile_max == -float("inf")

        tile_sum = tl.where(
            all_neg_inf,
            0.0,
            tl.sum(tl.exp(tile - tile_max[:, None]), axis=1),
        )

        tl.store(max_buf_ptr + offs_m * T + tile_id, tile_max, mask=(offs_m < M))
        tl.store(sum_buf_ptr + offs_m * T + tile_id, tile_sum, mask=(offs_m < M))

421 

422 

@triton.jit
def log_softmax_kernel_inner_k_merge_stats(
    max_buf_ptr,
    sum_buf_ptr,
    gmax_ptr,
    gsum_ptr,
    M: tl.constexpr,
    T: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """Pass 2 of the split-N path: merge per-tile stats into row stats.

    Reads the (M, T) per-tile max/sum buffers produced by
    log_softmax_kernel_inner_k_partial_stats and writes the global
    per-row max (gmax) and the total exp-sum rescaled to it (gsum).
    """
    pid_m = tl.program_id(axis=0)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offs_m < M

    tile_max = tl.load(
        max_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=-float("inf"),
    )
    tile_sum = tl.load(
        sum_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=0.0,
    ).to(tl.float32)

    gmax = tl.max(tile_max, axis=1)
    # Rescale each tile's sum from its local max to the global row max.
    scale = tl.exp(tile_max - gmax[:, None])
    # All -inf row: exp(-inf - -inf) would be NaN, force the scale to 0.
    scale = tl.where(gmax[:, None] == -float("inf"), 0.0, scale)
    gsum = tl.sum(tile_sum * scale, axis=1)

    tl.store(gmax_ptr + offs_m, gmax, mask=mask_m)
    tl.store(gsum_ptr + offs_m, gsum, mask=mask_m)

455 

456 

@triton.jit
def log_softmax_kernel_inner_k_write_logsoftmax(
    x_ptr,
    y_ptr,
    gmax_ptr,
    gsum_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Pass 3 of the split-N path: write log_softmax using merged stats.

    For each (row-block, tile) pair, computes
    ln(exp(x - gmax) / gsum) with the per-row gmax/gsum produced by the
    merge kernel.  Same task distribution as the partial-stats kernel
    (and the same M // BLOCK_M divisibility assumption).
    """
    # log2(e); tl.log2(x) / log2e == ln(x).
    log2e = 1.442695
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    total_blocks = (M // BLOCK_M) * T
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        gmax = tl.load(gmax_ptr + offs_m, mask=(offs_m < M), other=-float("inf")).to(
            tl.float32
        )
        gsum = tl.load(gsum_ptr + offs_m, mask=(offs_m < M), other=0.0).to(tl.float32)

        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        # gsum == 0 marks an all -inf row; its softmax is defined as 0 here
        # (the ln below then yields -inf for such rows).
        valid = gsum[:, None] > 0

        o = tl.where(
            valid,
            tl.exp(tile - gmax[:, None]) / gsum[:, None],
            0.0,
        )
        out = tl.log2(o) / log2e

        tl.store(y_ptr + offs_m[:, None] * N + offs_n[None, :], out, mask=mask)

506 

507 

508# ------------------------ backward ------------------------------- 

509 

510 

def config_prune3(configs, named_args, **kwargs):
    """Early-prune autotune configs for the backward non-inner kernel.

    Same structure as config_prune1, but the backward pass keeps three
    working buffers (output, grad_output, grad_input), so the NRAM
    coefficients are larger.  named_args must provide "M", "N", "K" and
    the output tensor "output_ptr" (its dtype selects the coefficient).
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    output = named_args["output_ptr"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        TILE_K, TILE_N, num_warps, num_stages = (
            kw["TILE_K"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Deep-copy before mutating the shared candidate config.
            config = copy.deepcopy(config)
            TILE_N = config.kwargs["TILE_N"] = N
            # K rows handled by each core when K is split over cores.
            k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
            # Estimated bytes: three fp32 (TILE_N, k_per_core) buffers plus
            # one fp32 vector of length k_per_core.
            nram_usage = (3 * TILE_N + 1) * k_per_core * 4
            if nram_usage < MAX_NRAM_SIZE:
                TILE_K = config.kwargs["TILE_K"] = k_per_core
                num_stages = config.num_stages = 1
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
            else:
                # Largest TILE_K that fits without software pipelining.
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (3 * TILE_N + 1)
                TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
                num_stages = config.num_stages = 1
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)

                # Pipelined variant (num_stages=3); pipelining needs extra
                # buffers, one more for fp32 data.
                config = copy.deepcopy(config)
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (6 * TILE_N + 1)
                if output.dtype == torch.float32:
                    max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (7 * TILE_N + 1)
                TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
                num_stages = config.num_stages = 3
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
        else:
            # Large N: keep the candidate as-is, dedupe by key.
            key = (TILE_K, TILE_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
    pruned_configs = []
    for k, v in configs_map.items():
        pruned_configs.append(v)
    # Heuristic fallbacks: full-N tile with TILE_K=1, piped and non-piped.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["TILE_K"] = 1
    extra_config.kwargs["TILE_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs

566 

567 

@libentry()
@libtuner(
    configs=[
        triton.Config({"TILE_K": k, "TILE_N": 2**n}, num_stages=s, num_warps=1)
        for k in [1, 2, 4, 8]
        for n in range(10, 15, 1)
        for s in [1, 3]
    ],
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune3},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: log_softmax_tile_mode_for_non_inner(
            args["M"], args["N"], args["K"], args["TILE_N"], args["TILE_K"]
        ),
    },
)
@triton.jit
def log_softmax_backward_kernel_non_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    """Backward of log_softmax over axis N of an (M, N, K) view, K > 1.

    Computes in_grad = out_grad - exp(output) * sum(out_grad, axis=N),
    where `output` is the forward log_softmax result.  TILE_MODE follows
    log_softmax_tile_mode_for_non_inner.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # Each program along axis 1 handles a contiguous slice of K.
    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)
    k_start = pid_k * split_k

    if TILE_MODE == 0:
        # Single tile: TILE_N >= N here, so no mask is needed along N.
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        # scale = sum of upstream gradients along N per K column.
        scale = tl.sum(out_grad_tile, axis=0)
        in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = k_offset[None, :] < K
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_grad_tile, axis=0)
            in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        # N is tiled: pass 1 accumulates the gradient sum, pass 2 writes.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_grad_tile
            # Collapse partials into one gradient sum per K column.
            scale = tl.sum(scale, axis=0)
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)

649 

650 

def nram_usage_for_backward_inner(bm, bn, tile_mode, num_stages, dtype):
    """Estimate the NRAM bytes used by the inner backward kernel.

    The coefficient counts how many fp32 working buffers of shape
    (bm, bn) stay resident for the given tile mode, pipeline depth
    (num_stages) and element dtype; fp32 inputs need extra buffers.
    """
    is_fp32 = dtype == torch.float32
    if tile_mode == 0:
        buffer_count = 5 if is_fp32 else 4
        return buffer_count * bn * bm * 4
    if tile_mode == 1:
        coef = 3 if num_stages == 1 else (8 if is_fp32 else 6)
    else:
        coef = 4 if num_stages == 1 else (11 if is_fp32 else 8)
    # One extra fp32 vector of length bm on top of the tiled buffers.
    return (coef * bn + 1) * bm * 4

675 

676 

def config_prune4(configs, named_args, **kwargs):
    """Early-prune autotune configs for the backward inner kernel (K == 1).

    If a per-core full-row tile fits in NRAM, returns exactly one config
    with no loops.  Otherwise keeps only candidates whose BLOCK_N is
    256-byte aligned and whose estimated NRAM usage lies in
    (MAX_NRAM_SIZE // 2, MAX_NRAM_SIZE].  named_args must provide "M",
    "N" and the output tensor "output_ptr".
    """
    M = named_args["M"]
    N = named_args["N"]
    output = named_args["output_ptr"]
    dtype = output.dtype
    m_per_core = math.ceil(M / TOTAL_CORE_NUM)
    # No need for any loop.
    if nram_usage_for_backward_inner(m_per_core, N, 0, 1, dtype) < MAX_NRAM_SIZE:
        config = copy.deepcopy(configs[0])
        config.kwargs["BLOCK_M"] = m_per_core
        config.kwargs["BLOCK_N"] = N
        config.num_stages = 1
        return [config]
    # Elements per 256 bytes: 64 for 4-byte dtypes, 128 for 2-byte dtypes.
    align_num = 256 // 4 if dtype == torch.float32 else 256 // 2
    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_stages,
        )
        # Align the lowest dimension to 256B while loading/storing data.
        if BLOCK_N % align_num != 0:
            continue
        # nram usage should be smaller than MAX_NRAM_SIZE (and large enough
        # to make good use of it).
        mode = log_softmax_tile_mode_for_inner(M, N, BLOCK_M, BLOCK_N)
        nram = nram_usage_for_backward_inner(BLOCK_M, BLOCK_N, mode, num_stages, dtype)
        if nram > MAX_NRAM_SIZE or nram < MAX_NRAM_SIZE // 2:
            continue
        pruned_configs.append(config)
    return pruned_configs

709 

710 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_N": 64 * k, "BLOCK_M": 2**n}, num_stages=s, num_warps=1)
        for k in range(1, 17)
        for n in range(3, 10, 1)
        for s in [1, 3]
    ],
    key=[
        "N",
        "M",
    ],
    prune_configs_by={"early_config_prune": config_prune4},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: log_softmax_tile_mode_for_inner(
            args["M"], args["N"], args["BLOCK_M"], args["BLOCK_N"]
        ),
    },
)
@triton.jit
def log_softmax_backward_kernel_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    """Backward of log_softmax over the last axis of an (M, N) view.

    Computes in_grad = out_grad - exp(output) * sum(out_grad, axis=1),
    where `output` is the forward log_softmax result.  TILE_MODE follows
    log_softmax_tile_mode_for_inner.
    """
    pid_m = tl.program_id(0)
    # Each program handles a contiguous slice of M rows.
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)
    m_start = pid_m * split_m

    if TILE_MODE == 0:
        # Single tile: BLOCK_N >= N here, so no mask is needed along N.
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        # scale = per-row sum of the upstream gradients.
        scale = tl.sum(out_grad_tile, 1)
        in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M and n_offset[None, :] < N
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_grad_tile, 1)
            in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        # N is tiled: pass 1 accumulates the per-row gradient sum, pass 2
        # writes the input gradients.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            scale = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_grad_tile
            # Collapse the per-column partials into one sum per row.
            scale = tl.sum(scale, 1)
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                # evict_first: the forward output is not re-read afterwards.
                out_tile = tl.load(
                    output_ptr + offset, mask=mask, eviction_policy="evict_first"
                ).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)

790 

791 

def log_softmax(self, dim, half_to_float=False):
    """Compute log_softmax of ``self`` along ``dim``.

    The tensor is viewed as (M, N, K) with N = shape[dim]; the non-inner
    kernel handles K > 1, otherwise one of two inner strategies is used
    depending on M and N.

    Args:
        self: input tensor.
        dim: reduction dimension (negative values allowed).
        half_to_float: when True, the output is float32 regardless of the
            input dtype.

    Returns:
        A new tensor with the same shape as ``self``.
    """
    logger.debug("GEMS_CAMBRICON LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    # All kernels use contiguous row-major index arithmetic.
    inp = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N

    with torch_device_fn.device(inp.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON LOGSOFTMAX USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            log_softmax_kernel_non_inner[grid](
                out,
                inp,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON LOGSOFTMAX USE INNER")
            if M > TOTAL_CORE_NUM or N < 1024 * 8 * 8:
                log_softmax_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                    out,
                    inp,
                    M,
                    N,
                )
            else:
                # Few very long rows: three-pass split-N reduction.
                block_m = 1
                block_n = 8192 * 4
                if dtype is torch.float32:
                    block_n = 8192 * 2
                # workspace: per-tile stats (M, T) and merged per-row stats.
                T = (N + block_n - 1) // block_n
                max_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                sum_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                gmax = torch.empty((M,), device=self.device, dtype=torch.float32)
                gsum = torch.empty((M,), device=self.device, dtype=torch.float32)
                # kernel 1: per-tile stats.
                # Bug fix: pass the contiguous `inp`, not `self` — the kernel
                # indexes the buffer with contiguous row-major arithmetic.
                log_softmax_kernel_inner_k_partial_stats[(TOTAL_CORE_NUM,)](
                    inp,
                    max_buf,
                    sum_buf,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
                # kernel 2: merge stats along N-tiles
                grid_merge = (triton.cdiv(M, block_m),)
                log_softmax_kernel_inner_k_merge_stats[grid_merge](
                    max_buf, sum_buf, gmax, gsum, M, T, BLOCK_M=block_m
                )
                # Smaller tiles for the write pass (it holds input and output).
                block_n = block_n // 2
                T = (N + block_n - 1) // block_n
                # kernel 3: write normalized outputs (same contiguity fix).
                log_softmax_kernel_inner_k_write_logsoftmax[(TOTAL_CORE_NUM,)](
                    inp,
                    out,
                    gmax,
                    gsum,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
    return out

875 

876 

def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Backward (VJP) of log_softmax along ``dim``.

    Computes grad_input = grad_output - exp(output) * sum(grad_output, dim).

    Args:
        grad_output: upstream gradient tensor.
        output: forward log_softmax result.
        dim: reduction dimension used in the forward pass.
        input_dtype: dtype of the original forward input (unused here but
            part of the aten signature).

    Returns:
        The gradient with respect to the forward input.
    """
    logger.debug("GEMS_CAMBRICON LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    # Bug fix: the kernels index `output` with contiguous row-major
    # arithmetic, so it must be made contiguous as well (previously only
    # `grad_output` was).
    output = output.contiguous()
    in_grad = torch.empty_like(output)
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON LOG SOFTMAX VJP USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            log_softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON LOG SOFTMAX VJP USE INNER")
            # Removed a dead `grid` lambda that referenced a nonexistent
            # "TILE_M" meta key; the launch grid below is fixed.
            log_softmax_backward_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad