Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cummax.py: 0%

242 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_min 

13 

14Tensor = torch.Tensor 

15 

16logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

17 

18 

@triton.jit
def tl_cummax(input, index, axis=0):
    # Inclusive scan of (value, index) pairs along `axis`.
    # The combine function takes the max and, on equal values, keeps the
    # right-hand (larger) index — matching torch.cummax tie-breaking.
    return tl.associative_scan(
        (input, index), axis, tle.maximum_with_index_tie_break_right
    )

24 

25 

@triton.jit
def tl_max_tie_break_right(input, index, axis=None, keep_dims=False):
    # Reduction counterpart of tl_cummax: returns (max value, index of max),
    # breaking ties toward the right-most index. Needed because tl.max's
    # built-in argmax does not offer tie-break-right semantics.
    return tl.reduce(
        (input, index),
        axis,
        tle.maximum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )

34 

35 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_max_kernel(
    out,
    out_indices,
    partial_max,
    partial_max_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # "Fan" phase of scan-then-fan over a flat array: merge the running
    # maximum of all preceding blocks (already scanned into partial_max)
    # into this block's local cummax results, in place.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    # NOTE: rebinds the `out_indices` parameter to the loaded values; the
    # pointer base is no longer needed after out_indices_ptrs is computed.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    # Block 0 has no predecessor; its local scan is already the final answer.
    if pid > 0:
        # partial_max[pid - 1] holds the cummax (value, index) over all
        # elements before this block, produced by the recursive scan.
        partial_max_ptrs = partial_max + pid - 1
        last_part_max_via_max = tl.load(partial_max_ptrs)
        partial_max_indices_ptrs = partial_max_indices + pid - 1
        last_part_max_index_via_max = tl.load(partial_max_indices_ptrs)

        final_vals = tl.maximum(out_vals, last_part_max_via_max)
        # >= keeps the local (later) index on ties — tie-break right.
        final_indices = tl.where(
            out_vals >= last_part_max_via_max, out_indices, last_part_max_index_via_max
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

67 

68 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_max_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_max,
    partial_max_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    # "Scan" phase over a flat array: each program computes the cummax of its
    # BLOCK_SIZE chunk and (when NEED_PARTIAL) the chunk's overall max, which
    # the host then scans recursively and fans back via add_base_max_kernel.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes load the dtype minimum so they never win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
    # Promote narrow types for the scan; 64-bit types are kept as-is to
    # avoid losing precision/range.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive passes carry the original element indices in out_indices.
        # NOTE(review): the `in_indices` parameter appears unused — the host
        # passes the same tensor for both; indices are read from out_indices.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First pass: the flat offset is the element's own index.
        in_indices_vals = offset
    result, cummax_indices = tl_cummax(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.max do not support max_indices_tie_break_right
        part_max_via_max, part_max_indices_via_max = tl_max_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummax_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One (max, index) pair per block for the recursive scan.
        partial_max_ptrs = partial_max + pid
        tl.store(partial_max_ptrs, part_max_via_max)

        partial_max_indices_ptrs = partial_max_indices + pid
        tl.store(partial_max_indices_ptrs, part_max_indices_via_max)

124 

125 

def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Cummax over a flat 1-D tensor using scan-then-fan.

    Each block scans its chunk locally; the per-block maxima are then scanned
    recursively (this same function on the much smaller ``partial_max``) and
    fanned back onto every block's local results.

    Args:
        inp: contiguous input tensor (flat view).
        out: output tensor for cummax values (written in place).
        out_indices: int64 output tensor for cummax indices.
        n_ele: number of elements to scan.
        dtype: dtype for the partial-max scratch buffer.
        use_out_indices: when True (recursive passes), element indices are
            read from ``out_indices`` instead of being the flat offsets.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if n_ele <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(n_ele)
    part_num = math.ceil(n_ele / BLOCK_SIZE)
    # Partial (per-block) maxima are only needed when more than one block runs.
    need_partial = part_num >= 2
    if need_partial:
        partial_max = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_max_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )
    else:
        partial_max = None
        partial_max_indices = None

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_max_kernel[grid](
            inp,
            out,
            out_indices,  # in_indices (unused by the kernel; kept for signature)
            out_indices,
            partial_max,
            partial_max_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Recursion terminates because part_num < n_ele at every level.
        scan_then_fan_col(
            partial_max,
            partial_max,
            partial_max_indices,
            part_num,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_max_kernel[grid](
                out, out_indices, partial_max, partial_max_indices, n_ele, BLOCK_SIZE
            )

170 

171 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_max_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_max,
    partial_max_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    # Scan phase for an (A, B, C) view, scanning along the middle dim B.
    # Grid is (A, part_num, C): each program handles one BLOCK_SIZE chunk of
    # one (a, c) column, striding by C between consecutive b positions.
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Element offset into the contiguous (A, B, C) layout.
    offset = a_idx * B * C + b_idx * C + c_idx
    # Offset into the (A, part_num, C) partial-max buffer for this program.
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Out-of-range lanes get the dtype minimum so they never win the max.
    min_value = get_dtype_min(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
    # Promote narrow types for the scan; 64-bit types are kept as-is.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive passes carry original indices in out_indices.
        # NOTE(review): `in_indices` appears unused — indices are read from
        # out_indices; the host passes the same tensor for both.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First pass: position along B is the element's own index.
        in_indices_vals = b_idx
    result, cummax_indices = tl_cummax(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.max do not support max_indices_tie_break_right
        part_max_via_max, part_max_indices_via_max = tl_max_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummax_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One (max, index) pair per chunk for the recursive scan.
        partial_max_ptrs = partial_max + part_offset
        tl.store(partial_max_ptrs, part_max_via_max)

        partial_max_indices_ptrs = partial_max_indices + part_offset
        tl.store(partial_max_indices_ptrs, part_max_indices_via_max)

238 

239 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_max_abc_kernel(
    out,
    out_indices,
    partial_max,
    partial_max_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Fan phase for the (A, B, C) layout: merge the running maximum of all
    # preceding chunks of this (a, c) column into the local cummax results.
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    # (A, part_num, C) buffer; read the entry of the *previous* chunk, which
    # after the recursive scan holds the cummax over everything before us.
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    out_indices_ptrs = out_indices + offset
    # NOTE: rebinds the `out_indices` parameter to the loaded values; the
    # pointer base is no longer needed after out_indices_ptrs is computed.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    # Chunk 0 has no predecessor; its local scan is already final.
    if pid_b > 0:
        partial_max_ptrs = partial_max + last_part_offset
        last_part_max_via_max = tl.load(partial_max_ptrs)
        partial_max_index_ptrs = partial_max_indices + last_part_offset
        last_part_max_index_via_max = tl.load(partial_max_index_ptrs)

        final_vals = tl.maximum(out_vals, last_part_max_via_max)
        # >= keeps the local (later) index on ties — tie-break right.
        final_indices = tl.where(
            out_vals >= last_part_max_via_max, out_indices, last_part_max_index_via_max
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

283 

284 

def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Cummax along the middle dim of an (A, B, C) view using scan-then-fan.

    Each (a, chunk-of-B, c) program scans locally; per-chunk maxima are then
    scanned recursively (this same function with B := part_num) and fanned
    back onto the local results.

    Args:
        inp: contiguous input viewed as (A, B, C).
        out: output tensor for cummax values (written in place).
        out_indices: int64 output tensor for cummax indices.
        A, B, C: logical dimensions; the scan runs along B.
        dtype: dtype for the partial-max scratch buffer.
        use_out_indices: when True (recursive passes), element indices are
            read from ``out_indices`` instead of being positions along B.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    part_num = math.ceil(B / BLOCK_SIZE)
    # Partial (per-chunk) maxima are only needed when B spans several chunks.
    need_partial = part_num >= 2
    if need_partial:
        partial_max = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_max_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )
    else:
        partial_max = None
        partial_max_indices = None

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_max_abc_kernel[grid](
            inp,
            out,
            out_indices,  # in_indices (unused by the kernel; kept for signature)
            out_indices,
            partial_max,
            partial_max_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Recursion terminates because part_num < B at every level.
        scan_then_fan(
            partial_max,
            partial_max,
            partial_max_indices,
            A,
            part_num,
            C,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_max_abc_kernel[grid](
                out,
                out_indices,
                partial_max,
                partial_max_indices,
                B,
                C,
                part_num,
                BLOCK_SIZE,
            )

340 

341 

@libentry()
@triton.jit()
def scan_part_max_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-kernel cummax along B of an (A, B, C) view: grid is (A, C) and
    # each program loops over B in BLOCK_SIZE chunks, carrying the running
    # (max, index) prefix across iterations — no partial buffers needed.
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # init, promote low precision types
    min_value = get_dtype_min(inp.type.element_ty)
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Loop-carried running maximum (scalar) and its index; start at the
    # dtype minimum so the first chunk always wins.
    prev_max_val = tl.full([], min_value, dtype=compute_dtype)
    prev_max_val_idx = tl.full([], 0, dtype=tl.int64)
    # Selects the last lane of a chunk — used to extract the chunk's final
    # running value below.
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        inp_vals = tl.load(inp + offset, mask=mask, other=min_value)
        # Only promote if necessary
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # cummax
        result, cummax_indices = tl_cummax(vals, idxs, axis=0)

        # broadcast
        prev_max_val_b = tl.broadcast_to(prev_max_val, (BLOCK_SIZE,))
        prev_max_val_idx_b = tl.broadcast_to(prev_max_val_idx, (BLOCK_SIZE,))

        # Handle NaN and tie-breaking logic
        if tl.constexpr(compute_dtype.is_floating()):
            # For floats: handle NaN propagation + tie-break right.
            # x != x is the standard NaN test.
            prev_is_nan = prev_max_val != prev_max_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            # Take the local result when it is NaN (propagate), or when the
            # carried prefix is not NaN and the local result wins (>= keeps
            # the later index on ties).
            use_result = result_is_nan | (~prev_nan_mask & (result >= prev_max_val_b))
        else:
            # For integers: simple tie-break right
            use_result = result >= prev_max_val_b

        final_vals = tl.where(use_result, result, prev_max_val_b)
        final_indices = tl.where(use_result, cummax_indices, prev_max_val_idx_b)

        # update global max val and idx
        # Masked sum extracts the last lane's value as a scalar; a NaN there
        # propagates through the sum as intended.
        prev_max_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_max_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        # store result
        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)

420 

421 

def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Cummax along the middle dim of an (A, B, C) view via the looping kernel.

    Launches one program per (a, c) pair; each program iterates over B in
    BLOCK_SIZE chunks, so no partial-max scratch buffers or fan pass are
    needed. Preferred when A * C is large.

    Args:
        inp: contiguous input viewed as (A, B, C).
        out: output tensor for cummax values (written in place).
        out_indices: int64 output tensor for cummax indices.
        A, B, C: logical dimensions; the scan runs along B.
        dtype: unused here; kept for signature parity with scan_then_fan.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    # Use <= so B == 4096 still gets a single power-of-two block, consistent
    # with scan_then_fan and scan_then_fan_col.
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    loop_num = math.ceil(B / BLOCK_SIZE)

    grid = (A, C)
    with torch_device_fn.device(inp.device):
        scan_part_max_abc_loop_kernel[grid](
            inp,
            out,
            out_indices,
            B,
            C,
            loop_num,
            BLOCK_SIZE,
            # NOTE(review): `is_use_mask_zero` is not a parameter of the
            # kernel; presumably a kunlunxin-specific launch option — confirm
            # the vendor Triton accepts it, otherwise this launch will raise.
            is_use_mask_zero=True,
        )

441 

442 

def cummax(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummax:
    """Cumulative maximum along ``dim``, returning (values, indices).

    The input is viewed as (M, N, K) with N = shape[dim]; a strategy is
    picked by problem shape: a flat scan for 1-D, scan-then-fan for few
    rows, and a per-row looping kernel otherwise.

    Args:
        input: tensor of any shape; bool inputs are promoted to int64 values.
        dim: dimension to scan over (negative values allowed).
        out: optional (values, indices) pair to write results into.

    Returns:
        (values, indices) — cummax values and the int64 index of each max.
    """
    logger.debug("GEMS cummax")
    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    shape = input.shape
    dim = dim % input.ndim
    # Collapse the shape to (M, N, K) around the scan dimension.
    M = 1
    N = shape[dim]
    for i in range(dim):
        M *= shape[i]
    input = input.contiguous()
    K = input.numel() // M // N

    dtype = input.dtype
    if dtype is torch.bool:
        dtype = torch.int64
    values = torch.empty_like(input, dtype=dtype)
    indices = torch.empty_like(input, dtype=torch.int64)

    # Half-precision inputs are scanned in float32 scratch for accuracy.
    compute_dtype = values.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        scan_then_fan_col(input, values, indices, N, compute_dtype)
    elif M * K <= 16:
        scan_then_fan(input, values, indices, M, N, K, compute_dtype)
    else:
        scan_then_fan_loop(input, values, indices, M, N, K, compute_dtype)

    # Honor a caller-provided (values, indices) pair; previously `out` was
    # accepted but silently discarded because the local was rebound.
    if isinstance(out, (tuple, list)) and len(out) == 2:
        out_values, out_idx = out
        out_values.copy_(values)
        out_idx.copy_(indices)
        return out_values, out_idx
    return values, indices