Coverage for src/flag_gems/runtime/backend/_cambricon/ops/cummin.py: 0%

242 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2import math 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_max 

13 

# Shorthand alias used in the type annotations below.
Tensor = torch.Tensor
# Child logger nested under the shared "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

16 

17 

@triton.jit
def tl_cummin(input, index, axis=0):
    # Inclusive scan returning both the running minimum and the index at
    # which it occurred. Ties are broken toward the right (the later index
    # wins), as implemented by the combine function used here.
    return tl.associative_scan(
        (input, index), axis, tle.minimum_with_index_tie_break_right
    )

23 

24 

@triton.jit
def tl_min_tie_break_right(input, index, axis=None, keep_dims=False):
    # Reduction counterpart of tl_cummin: returns (min value, index of min)
    # with ties broken toward the later index. Used instead of tl.min, which
    # does not support these tie-break semantics (see callers' comments).
    return tl.reduce(
        (input, index),
        axis,
        tle.minimum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )

33 

34 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_min_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan phase for the 1-D scan: fold the running minimum of all preceding
    blocks into this block's per-element cummin results.

    `partial_min`/`partial_min_indices` already hold an inclusive cummin over
    the per-block minima (see scan_then_fan_col), so slot pid-1 is the
    minimum of everything before this block.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    # NOTE: rebinds the name `out_indices` from pointer to loaded values;
    # the pointer stays reachable via out_indices_ptrs.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid > 0:  # block 0 has no preceding blocks; its results are final
        partial_min_ptrs = partial_min + pid - 1
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_indices_ptrs = partial_min_indices + pid - 1
        last_part_min_index_via_min = tl.load(partial_min_indices_ptrs)

        final_vals = tl.minimum(out_vals, last_part_min_via_min)
        # `<=` keeps the in-block index on ties; in-block indices are larger
        # than preceding ones, so this is the tie-break-right rule.
        final_indices = tl.where(
            out_vals <= last_part_min_via_min, out_indices, last_part_min_index_via_min
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

66 

67 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_min_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Scan phase for the 1-D case: block-local inclusive cummin of `inp`.

    Writes per-element cummin values/indices to `out`/`out_indices` and,
    when NEED_PARTIAL, each block's overall minimum value/index to
    `partial_min`/`partial_min_indices` (one slot per block) for the
    subsequent fan step.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes load the dtype's max so they never win the minimum.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Promote narrow types to 32-bit compute types; 64-bit types are kept
    # as-is so no precision/range is lost.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive level: source indices were pre-seeded into out_indices by
        # the previous level. (The `in_indices` parameter itself is unused.)
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First level: the index is simply the element's position.
        in_indices_vals = offset
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.min do not support min_indices_tie_break_right
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_min_ptrs = partial_min + pid
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + pid
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)

123 

124 

def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Compute a 1-D inclusive cummin (values + indices) of `inp` into
    `out`/`out_indices` using a two-pass scan-then-fan strategy.

    Args:
        inp: contiguous 1-D input tensor (may alias `out` on recursion).
        out / out_indices: pre-allocated result tensors of length n_ele.
        n_ele: number of elements to scan.
        dtype: dtype for the partial-minimum buffer (compute dtype).
        use_out_indices: True on recursive calls, where out_indices already
            holds the source indices for this level.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if n_ele <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(n_ele)
    part_num = math.ceil(n_ele / BLOCK_SIZE)
    # A single block needs no cross-block fix-up and no partial buffers.
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = None
        partial_min_indices = None

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        # The kernel's in_indices argument is unused; out_indices doubles as
        # the index source when use_out_indices is set.
        scan_part_min_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Recursively cummin the per-block minima in place, then fold each
        # block's preceding running minimum back into its elements.
        scan_then_fan_col(
            partial_min,
            partial_min,
            partial_min_indices,
            part_num,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_kernel[grid](
                out, out_indices, partial_min, partial_min_indices, n_ele, BLOCK_SIZE
            )

169 

170 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_min_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Scan phase for the (A, B, C) layout: block-local cummin along B.

    The contiguous input is viewed as (A, B, C) with the scan running over B.
    The grid is (A, part_num, C): each program handles BLOCK_SIZE consecutive
    B positions for one (a, c) pair. When NEED_PARTIAL, each program also
    writes its block's overall minimum value/index into the
    (A, part_num, C)-shaped partial buffers at [a, pid_b, c].
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flattened element offsets for a contiguous (A, B, C) tensor.
    offset = a_idx * B * C + b_idx * C + c_idx
    # Offsets into the (A, part_num, C) partial buffers.
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Masked lanes load the dtype's max so they never win the minimum.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Promote narrow types to 32-bit compute types; 64-bit types stay as-is.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Recursive level: source indices were pre-seeded into out_indices by
        # the previous level. (The `in_indices` parameter itself is unused.)
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First level: the index is the coordinate along the scanned B axis.
        in_indices_vals = b_idx
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # tl.min do not support min_indices_tie_break_right
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_min_ptrs = partial_min + part_offset
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + part_offset
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)

237 

238 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_min_abc_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan phase for the (A, B, C) layout: fold the running minimum of all
    preceding B-blocks into this block's cummin results.

    The (A, part_num, C) partial buffers already hold a cummin over the
    per-block minima (see scan_then_fan), so slot [a, pid_b - 1, c] is the
    minimum of everything before this block along B.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flattened element offsets for a contiguous (A, B, C) tensor.
    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    base_part_offset = a_idx * part_num * C + c_idx
    # Slot holding the running minimum of the immediately preceding block.
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    out_indices_ptrs = out_indices + offset
    # NOTE: rebinds `out_indices` from pointer to loaded values; the pointer
    # remains reachable through out_indices_ptrs.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid_b > 0:  # the first block along B has nothing before it
        partial_min_ptrs = partial_min + last_part_offset
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_index_ptrs = partial_min_indices + last_part_offset
        last_part_min_index_via_min = tl.load(partial_min_index_ptrs)

        final_vals = tl.minimum(out_vals, last_part_min_via_min)
        # `<=` keeps the in-block (later) index on ties: tie-break right.
        final_indices = tl.where(
            out_vals <= last_part_min_via_min, out_indices, last_part_min_index_via_min
        )
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

282 

283 

def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Cummin along the B axis of a contiguous (A, B, C)-shaped tensor using
    the two-pass scan-then-fan scheme.

    Args:
        inp: contiguous input viewed as (A, B, C) (may alias `out` on
            recursion).
        out / out_indices: pre-allocated result tensors of the same shape.
        A, B, C: logical dimensions; the scan runs over B.
        dtype: dtype for the partial-minimum buffer (compute dtype).
        use_out_indices: True on recursive calls, where out_indices already
            holds the source indices for this level.
    """
    # TODO(all): tune on target board
    BLOCK_SIZE = 1024
    if B <= 1024 * 4:
        BLOCK_SIZE = triton.next_power_of_2(B)
    part_num = math.ceil(B / BLOCK_SIZE)
    # A single block along B needs no cross-block fix-up and no partials.
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = None
        partial_min_indices = None

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        # The kernel's in_indices argument is unused; out_indices doubles as
        # the index source when use_out_indices is set.
        scan_part_min_abc_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Recursively cummin the per-block minima along their part axis, then
        # fold each block's preceding running minimum back into its elements.
        scan_then_fan(
            partial_min,
            partial_min,
            partial_min_indices,
            A,
            part_num,
            C,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_abc_kernel[grid](
                out,
                out_indices,
                partial_min,
                partial_min_indices,
                B,
                C,
                part_num,
                BLOCK_SIZE,
            )

339 

340 

@libentry()
@triton.jit()
def scan_part_min_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Looping cummin along B for the (A, B, C) layout, one program per (a, c).

    Unlike the two-pass scan-then-fan scheme, each program iterates over the
    B axis in BLOCK_SIZE chunks and carries the running minimum (and its
    index) from one chunk to the next, so no partial buffers or second
    kernel launch are needed.
    """
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    # Flattened offset of (a, 0, c) in the contiguous (A, B, C) tensor.
    ac_offset = a_idx * B * C + c_idx

    # init: choose a compute dtype (promote 16-bit floats and 8/16-bit ints
    # to 32-bit) and seed the carried minimum with the dtype's max value.
    max_value = get_dtype_max(inp.type.element_ty)
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    prev_min_val = tl.full([], max_value, dtype=compute_dtype)
    prev_min_val_idx = tl.full([], 0, dtype=tl.int64)
    # Lane mask selecting the last element of a chunk (used below to extract
    # the chunk's final running minimum as a scalar).
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        # Masked lanes read the dtype's max so they never win the minimum.
        inp_vals = tl.load(inp + offset, mask=mask, other=max_value)
        # Only promote if necessary
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # cummin within the current chunk
        result, cummin_indices = tl_cummin(vals, idxs, axis=0)

        # broadcast the carried scalar min/index across the chunk lanes
        prev_min_val_b = tl.broadcast_to(prev_min_val, (BLOCK_SIZE,))
        prev_min_val_idx_b = tl.broadcast_to(prev_min_val_idx, (BLOCK_SIZE,))

        # Handle NaN and tie-breaking logic
        if tl.constexpr(compute_dtype.is_floating()):
            # For floats: handle NaN propagation + tie-break right
            prev_is_nan = prev_min_val != prev_min_val  # NaN != NaN
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            use_result = result_is_nan | (~prev_nan_mask & (result <= prev_min_val_b))
        else:
            # For integers: simple tie-break right
            use_result = result <= prev_min_val_b

        final_vals = tl.where(use_result, result, prev_min_val_b)
        final_indices = tl.where(use_result, cummin_indices, prev_min_val_idx_b)

        # update global min val and idx: zero every lane except the last,
        # then sum to extract that single lane as a scalar carry.
        prev_min_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_min_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        # store result
        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)

419 

420 

def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Launch the single-pass looping cummin kernel over the (A, B, C) view.

    One program per (a, c) pair walks the B axis in chunks, carrying the
    running minimum across chunks, so no partial buffers are required.
    `dtype` is accepted for signature parity with the other launchers but is
    not used here.
    """
    # TODO(all): tune on target board
    block = triton.next_power_of_2(B) if B < 1024 * 4 else 1024
    num_loops = math.ceil(B / block)

    with torch_device_fn.device(inp.device):
        scan_part_min_abc_loop_kernel[(A, C)](
            inp,
            out,
            out_indices,
            B,
            C,
            num_loops,
            block,
        )

439 

440 

def cummin(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummin:
    """Cumulative minimum of `input` along `dim`, returning (values, indices).

    The tensor is flattened to an (M, N, K) view with N = shape[dim]; the
    scan runs over N. A strategy is picked by shape: pure 1-D scan, the
    (A, B, C) scan-then-fan, or the per-(a, c) looping kernel.

    Args:
        input: input tensor (made contiguous internally).
        dim: dimension to scan over; may be negative.
        out: optional (values, indices) tensor pair to write results into.

    Returns:
        torch.return_types.cummin(values, indices).
    """
    logger.debug("GEMS_CAMBRICON CUMMIN")
    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    shape = input.shape
    dim = dim % input.ndim
    M = 1
    N = shape[dim]
    for i in range(dim):
        M *= shape[i]
    input = input.contiguous()
    K = input.numel() // M // N

    dtype = input.dtype
    if dtype is torch.bool:
        # Booleans are scanned as int64 (kernels have no bool min support).
        dtype = torch.int64
    # BUG FIX: the original rebound the `out` keyword argument here, so a
    # caller-provided `out` pair was silently ignored. Allocate results under
    # distinct names and copy into `out` at the end when it is given.
    values = torch.empty_like(input, dtype=dtype)
    indices = torch.empty_like(input, dtype=torch.int64)

    # Half-precision floats are accumulated in float32 inside the kernels.
    compute_dtype = values.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        # Pure 1-D scan.
        scan_then_fan_col(input, values, indices, N, compute_dtype)
    elif M * K <= 16:
        # Few (a, c) pairs: two-pass scan-then-fan over the (M, N, K) view.
        scan_then_fan(input, values, indices, M, N, K, compute_dtype)
    else:
        # Many (a, c) pairs: one looping program per pair.
        scan_then_fan_loop(input, values, indices, M, N, K, compute_dtype)

    if out is not None:
        out_values, out_indices = out
        out_values.copy_(values)
        out_indices.copy_(indices)
        return torch.return_types.cummin((out_values, out_indices))
    # Named-tuple result matching the declared return annotation; it still
    # unpacks like the plain tuple previously returned.
    return torch.return_types.cummin((values, indices))