Coverage for src/flag_gems/runtime/backend/_ascend/ops/argmin.py: 0%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_max 

13 

# Module logger: a fixed backend prefix followed by the last dotted component
# of this module's import name.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)

15 

16 

@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the flattened argmin: reduce one chunk per program.

    Each program loads up to ``BLOCK_SIZE`` consecutive elements of ``inp``
    starting at ``program_id * BLOCK_SIZE`` and stores the chunk minimum in
    ``mid_value[program_id]`` and its element offset from ``inp`` in
    ``mid_index[program_id]``.
    """
    block_id = tle.program_id(0)
    block_start = block_id * BLOCK_SIZE
    lanes = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < M

    # Lanes at or beyond M are filled with get_dtype_max(...) instead of being
    # read (presumably the largest value of the element type, so padding
    # never wins the reduction -- confirm against flag_gems.utils.limits).
    pad_value = get_dtype_max(inp.type.element_ty)
    chunk = tl.load(inp + lanes, mask=in_bounds, other=pad_value)
    chunk_min, pos_in_chunk = tl.min(chunk, axis=0, return_indices=True)

    # Convert the position inside the chunk to an offset from ``inp``.
    tl.store(mid_value + block_id, chunk_min)
    tl.store(mid_index + block_id, pos_in_chunk + block_start)

39 

40 

@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Second stage of the flattened argmin: pick the best partial result.

    Loads the first ``mid_size`` entries of ``mid_value``, finds the slot
    holding the smallest one, and copies the matching entry of ``mid_index``
    to ``out``.
    """
    slots = tl.arange(0, BLOCK_MID)
    valid = slots < mid_size

    # Slots at or beyond mid_size are filled with get_dtype_max(...) rather
    # than read from memory.
    pad_value = get_dtype_max(mid_value.type.element_ty)
    partial_mins = tl.load(mid_value + slots, mask=valid, other=pad_value)

    winner = tl.argmin(partial_mins, axis=0)
    tl.store(out, tl.load(mid_index + winner))

59 

60 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmin along the middle axis of an input addressed as (M, N, K).

    Element (m, n, k) is read from ``inp + m * N * K + n * K + k`` and the
    result for (m, k) is written to ``out_index + m * K + k``. Each program
    owns ``BLOCK_M`` consecutive values of ``m`` and walks every ``k``
    sequentially; the launch grid is one-dimensional.

    Args:
        inp: pointer to the input elements.
        out_index: pointer to the int64 result indices.
        M: extent of the leading (row) axis.
        N: extent of the reduced axis.
        K: extent of the trailing axis.
        BLOCK_M: rows handled per program (compile-time constant).
        BLOCK_N: reduced-axis elements handled per inner step (compile-time
            constant).
    """
    pid_m = tle.program_id(0)

    # None of these depend on the position along K, so they are computed once
    # instead of on every iteration of the loop below.
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = m_offset < M
    dtype = inp.type.element_ty
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)

    for pid_k in range(K):
        # Running minimum and its index for each row, reset for every k.
        min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
        argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)

        for start_n in range(0, N, BLOCK_N):
            n_offset = start_n + tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
            # Elementwise AND of the row and column bounds checks.
            mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
            inp_vals = tl.load(inp + offset, mask=mask, other=max_value)
            # tl.bfloat is promoted to tl.float32 by tl.min
            local_min, local_argmin = tl.min(
                inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict "<" keeps the index already recorded when a later chunk
            # only ties the running minimum.
            update = local_min < min_values
            min_values = tl.where(update, local_min, min_values)
            argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

        offset_index = m_offset * K + pid_k
        tl.store(out_index + offset_index, argmin_values, mask=row_mask)

104 

105 

def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the index of the minimum of ``inp``, overall or along ``dim``.

    Args:
        inp: input tensor.
        dim: dimension to reduce. ``None`` reduces over all elements and
            returns the index into the flattened (row-major) input.
        keepdim: when true, reduced dimensions are kept with size 1 (every
            dimension when ``dim`` is ``None``).
        dtype: dtype of the intermediate per-block minimum buffer used when
            ``dim`` is ``None``; defaults to ``inp.dtype``. It is not used
            when ``dim`` is given.
            NOTE(review): a dtype narrower than ``inp.dtype`` would convert
            the partial minima before the second stage compares them --
            confirm whether callers ever pass it.

    Returns:
        An int64 tensor of indices on ``inp.device``.

    Raises:
        IndexError: if ``dim`` is ``None`` and ``inp`` has no elements, or if
            the reduced dimension has size 0 (intended to mirror the errors
            ``torch.argmin`` reports for these inputs -- confirm).
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_ASCEND ARGMIN")
    if dim is None:
        M = inp.numel()
        if M == 0:
            # Without this guard block_size would be 0 and cdiv would fail
            # with an unhelpful ZeroDivisionError.
            raise IndexError(
                "argmin(): Expected reduction dim to be specified for "
                "input.numel() == 0. Specify the reduction dim with the "
                "'dim' argument."
            )
        if dtype is None:
            dtype = inp.dtype

        # argmin_kernel_1 addresses the input as M consecutive elements, so a
        # strided or transposed view must be materialised in row-major order
        # first; this is a no-op for tensors that are already contiguous.
        inp = inp.contiguous()

        # Two-stage reduction: about sqrt(M) blocks of about sqrt(M) elements.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmin_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmin_kernel_2[(1, 1, 1)](
                mid_value,
                mid_index,
                out,
                mid_size,
                block_mid,
            )
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    if N == 0:
        raise IndexError(
            f"argmin(): Expected reduction dim {dim} to have non-zero size."
        )
    # View the input as (M, N, K): leading dims, reduced dim, trailing dims.
    # K comes from the trailing dims directly so that a zero-size leading
    # dimension cannot cause a division by zero.
    M = math.prod(shape[:dim])
    K = math.prod(shape[dim + 1 :])

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    if M == 0 or K == 0:
        # The result has no elements; there is nothing for a kernel to write.
        return out_index

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        argmin_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
        )

    return out_index