Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cummin.py: 0%

243 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_max 

13 

14# from typing import List, Tuple, Union 

15 

16 

# (values, indices) result pair, mirroring the shape of torch.return_types.cummin.
CumminResult = namedtuple("CumminResult", ["values", "indices"])

# Shorthand alias for torch.Tensor (used by typed signatures).
Tensor = torch.Tensor

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

22 

23 

@triton.jit
def tl_cummin(input, index, axis=0):
    # Inclusive scan over (value, index) pairs; the combine function resolves
    # ties toward the later index, per its name (tie-break right).
    return tl.associative_scan(
        (input, index), axis, tle.minimum_with_index_tie_break_right
    )

29 

30 

@triton.jit
def tl_min_tie_break_right(input, index, axis=None, keep_dims=False):
    # Reduction counterpart of tl_cummin: min over (value, index) pairs with
    # ties resolved toward the later index. Used instead of tl.min because
    # tl.min does not offer this tie-break mode (see callers' comments).
    return tl.reduce(
        (input, index),
        axis,
        tle.minimum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )

39 

40 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_min_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan phase of the flat 1-D scan: fold the prefix minimum of all
    preceding blocks (partial_min[pid - 1]) into this block's local cummin.

    Block 0 has no predecessor and is left untouched.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    # NOTE: rebinds the parameter name to the loaded values; the pointer is
    # still reachable via out_indices_ptrs.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid > 0:
        # partial_min[pid - 1] holds the running minimum over blocks [0, pid)
        # (made cumulative by the recursive scan in scan_then_fan_col).
        partial_min_ptrs = partial_min + pid - 1
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_indices_ptrs = partial_min_indices + pid - 1
        last_part_min_index_via_min = tl.load(partial_min_indices_ptrs)

        final_vals = tl.minimum(out_vals, last_part_min_via_min)
        # `<=` keeps the local index on ties: local elements come after the
        # carried prefix, so this implements tie-break right.
        final_indices = tl.where(
            out_vals <= last_part_min_via_min, out_indices, last_part_min_index_via_min
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

72 

73 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_min_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Scan phase of the flat 1-D cummin: each program computes the local
    cummin of one BLOCK_SIZE chunk and, when NEED_PARTIAL, records the
    chunk's overall (min, index) for the later fan pass.

    `in_indices` is currently unused — callers pass `out_indices` for both
    index slots.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes read the dtype's max so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Promote narrow types to 32-bit for the scan; 64-bit types pass through.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive passes: element indices were already materialized in
        # out_indices by an earlier pass.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First pass: an element's index is simply its flat offset.
        in_indices_vals = offset
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.min do not support min_indices_tie_break_right
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One (min, index) pair per block, consumed by add_base_min_kernel.
        partial_min_ptrs = partial_min + pid
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + pid
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)

129 

130 

def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Cummin over a flat tensor of `n_ele` elements via scan-then-fan.

    The input is split into `part_num` blocks: the scan kernel computes
    per-block cummins plus per-block minima, the per-block minima are
    recursively scanned into prefix minima, and the fan kernel folds each
    block's prefix minimum back into its local results.

    Args:
        inp: contiguous 1-D input tensor.
        out: output tensor for cumulative minima (same length as `inp`).
        out_indices: int64 output tensor for argmin positions.
        n_ele: number of elements to scan.
        dtype: dtype for the intermediate partial-min buffer.
        use_out_indices: True on recursive calls, where each element's
            original index lives in `out_indices` rather than being the
            element's position.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if n_ele <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(n_ele)
    part_num = math.ceil(n_ele / BLOCK_SIZE)
    # A fan pass (and hence partial buffers) is only needed for >1 block.
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = None
        partial_min_indices = None

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_min_kernel[grid](
            inp,
            out,
            out_indices,  # in_indices slot: unused by the kernel
            out_indices,
            partial_min,
            partial_min_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Turn per-block minima into cumulative (prefix) minima. Terminates
        # because part_num < n_ele whenever more than one block is needed.
        scan_then_fan_col(
            partial_min,
            partial_min,
            partial_min_indices,
            part_num,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_kernel[grid](
                out, out_indices, partial_min, partial_min_indices, n_ele, BLOCK_SIZE
            )

175 

176 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_min_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Scan phase for an (A, B, C) view: each program computes the local
    cummin of one BLOCK_SIZE chunk along B for a fixed (a, c) pair and, when
    NEED_PARTIAL, records the chunk's overall (min, index) for the fan pass.

    Grid: (A, part_num, C). `in_indices` is currently unused — callers pass
    `out_indices` for both index slots.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flat offsets into the contiguous (A, B, C) input/output ...
    offset = a_idx * B * C + b_idx * C + c_idx
    # ... and into the (A, part_num, C) partial buffers.
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Out-of-range lanes read the dtype's max so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Promote narrow types to 32-bit for the scan; 64-bit types pass through.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive passes: element indices were already materialized in
        # out_indices by an earlier pass.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First pass: an element's index along B is b_idx itself.
        in_indices_vals = b_idx
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.min do not support min_indices_tie_break_right
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One (min, index) pair per chunk, consumed by add_base_min_abc_kernel.
        partial_min_ptrs = partial_min + part_offset
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + part_offset
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)

243 

244 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_min_abc_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan phase for the (A, B, C) scan: fold the prefix minimum of all
    preceding B-chunks (partial_min[a, pid_b - 1, c]) into this chunk's
    local cummin. Chunk 0 has no predecessor and is left untouched.

    Grid: (A, part_num, C).
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flat offsets into the contiguous (A, B, C) outputs ...
    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    # ... and into the (A, part_num, C) partials of the previous chunk.
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    out_indices_ptrs = out_indices + offset
    # NOTE: rebinds the parameter name to the loaded values; the pointer is
    # still reachable via out_indices_ptrs.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid_b > 0:
        # Running minimum over chunks [0, pid_b) (made cumulative by the
        # recursive scan in scan_then_fan).
        partial_min_ptrs = partial_min + last_part_offset
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_index_ptrs = partial_min_indices + last_part_offset
        last_part_min_index_via_min = tl.load(partial_min_index_ptrs)

        final_vals = tl.minimum(out_vals, last_part_min_via_min)
        # `<=` keeps the local index on ties (tie-break right: local elements
        # come after the carried prefix).
        final_indices = tl.where(
            out_vals <= last_part_min_via_min, out_indices, last_part_min_index_via_min
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

288 

289 

def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Cummin along the middle axis of an (A, B, C) view via scan-then-fan.

    Each (a, c) row of length B is split into `part_num` chunks: the scan
    kernel computes per-chunk cummins plus per-chunk minima, the per-chunk
    minima are recursively scanned into prefix minima, and the fan kernel
    folds each chunk's prefix minimum back into its local results.

    Args:
        inp: contiguous input viewed as (A, B, C).
        out: output tensor for cumulative minima (same shape).
        out_indices: int64 output tensor for argmin positions along B.
        A, B, C: collapsed sizes before / of / after the scan dimension.
        dtype: dtype for the intermediate partial-min buffer.
        use_out_indices: True on recursive calls, where each element's
            original index lives in `out_indices` rather than being the
            element's position.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    part_num = math.ceil(B / BLOCK_SIZE)
    # A fan pass (and hence partial buffers) is only needed for >1 chunk.
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = None
        partial_min_indices = None

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_kernel[grid](
            inp,
            out,
            out_indices,  # in_indices slot: unused by the kernel
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Turn per-chunk minima into cumulative (prefix) minima. Terminates
        # because part_num < B whenever more than one chunk is needed.
        scan_then_fan(
            partial_min,
            partial_min,
            partial_min_indices,
            A,
            part_num,
            C,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_abc_kernel[grid](
                out,
                out_indices,
                partial_min,
                partial_min_indices,
                B,
                C,
                part_num,
                BLOCK_SIZE,
            )

345 

346 

@libentry()
@triton.jit()
def scan_part_min_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Single-program cummin along B for one (a, c) pair: iterates over B in
    BLOCK_SIZE chunks, carrying the running (min, index) across chunks, so
    no separate fan pass is needed.

    Grid: (A, C).
    """
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # init
    max_value = get_dtype_max(inp.type.element_ty)
    # Carry dtype mirrors the promotion used by the block-scan kernels:
    # fp16/bf16 -> fp32, int8/int16 -> int32, everything else unchanged.
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Running minimum carried across chunks; starts at the dtype max so the
    # first chunk always wins.
    prev_min_val = tl.full([], max_value, dtype=compute_dtype)
    prev_min_val_idx = tl.full([], 0, dtype=tl.int64)
    # Selects the last lane of a chunk (used below to extract the carry).
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        # Out-of-range lanes read the dtype max so they never win the min.
        inp_vals = tl.load(inp + offset, mask=mask, other=max_value)
        # Only promote if necessary
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # cummin
        result, cummin_indices = tl_cummin(vals, idxs, axis=0)

        # broadcast
        prev_min_val_b = tl.broadcast_to(prev_min_val, (BLOCK_SIZE,))
        prev_min_val_idx_b = tl.broadcast_to(prev_min_val_idx, (BLOCK_SIZE,))

        # Handle NaN and tie-breaking logic
        if tl.constexpr(compute_dtype.is_floating()):
            # For floats: handle NaN propagation + tie-break right.
            # x != x is the NaN test; once either side is NaN the local result
            # sticks, otherwise fall through to the <= tie-break.
            prev_is_nan = prev_min_val != prev_min_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            use_result = result_is_nan | (~prev_nan_mask & (result <= prev_min_val_b))
        else:
            # For integers: simple tie-break right
            use_result = result <= prev_min_val_b

        final_vals = tl.where(use_result, result, prev_min_val_b)
        final_indices = tl.where(use_result, cummin_indices, prev_min_val_idx_b)

        # update global min val and idx
        # Masked sum extracts lane BLOCK_SIZE-1 as a scalar carry for the
        # next chunk (only that lane contributes; the rest sum 0).
        prev_min_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_min_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        # store result
        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)

425 

426 

def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Launch the single-pass loop kernel: one program per (a, c) pair walks
    the B axis in BLOCK_SIZE chunks, carrying the running minimum, so no
    fan pass is required.

    `dtype` is accepted for signature parity with the other launchers but is
    not used here (the kernel derives its compute dtype from `inp`).
    """
    # TODO(all): tune on target board
    block = 1024 if B >= 1024 * 4 else triton.next_power_of_2(B)
    n_loops = math.ceil(B / block)

    with torch_device_fn.device(inp.device):
        scan_part_min_abc_loop_kernel[(A, C)](
            inp,
            out,
            out_indices,
            B,
            C,
            n_loops,
            block,
            # NOTE(review): `is_use_mask_zero` is not a parameter of the
            # kernel — presumably a kunlunxin-Triton launch option; stock
            # Triton would reject an unknown kwarg. Verify on target.
            is_use_mask_zero=True,
        )

446 

447 

def cummin(input, dim=1):
    """Cumulative minimum of `input` along dimension `dim`.

    Returns a CumminResult(values, indices) namedtuple: values[..., i, ...]
    is the minimum over input[..., :i+1, ...] along `dim`, and indices holds
    the position (along `dim`) of each minimum. Bool inputs produce int64
    values, matching torch.cummin's promotion.

    Args:
        input: input tensor (made contiguous internally).
        dim: dimension to scan over; may be negative.

    Raises:
        AssertionError: if `dim` is out of range for `input`.
    """
    logger.debug("GEMS cummin")
    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    shape = input.shape
    dim = dim % input.ndim
    # Collapse to an (M, N, K) view: M = product of dims before `dim`,
    # N = size of `dim`, K = product of dims after.
    N = shape[dim]
    M = math.prod(shape[:dim])
    input = input.contiguous()
    K = input.numel() // M // N

    dtype = input.dtype
    if dtype is torch.bool:
        dtype = torch.int64
    out = torch.empty_like(input, dtype=dtype)
    out_indices = torch.empty_like(input, dtype=torch.int64)

    # Intermediate buffers for fp16/bf16 are kept in fp32, matching the
    # in-kernel promotion of those dtypes.
    compute_dtype = out.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        # Flat 1-D case: scan-then-fan over the whole tensor.
        scan_then_fan_col(input, out, out_indices, N, compute_dtype)
    elif M * K <= 16:
        # Few (m, k) rows: chunked scan-then-fan per row.
        scan_then_fan(input, out, out_indices, M, N, K, compute_dtype)
    else:
        # Many rows: one looping program per (m, k) row, no fan pass.
        scan_then_fan_loop(input, out, out_indices, M, N, K, compute_dtype)
    return CumminResult(out, out_indices)