Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/argmin.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Maps a torch dtype to (triton dtype used in-kernel, that dtype's maximum value).
# The max value is passed to the kernels as the `other=` fill for masked loads,
# so out-of-range lanes can never win a min-reduction.
# NOTE(review): bfloat16 maps to tl.float32 and the float32 max — presumably the
# reduction is carried out in fp32 for bf16 inputs; confirm against the kernels.
torch_dtype_to_tl_dtype_and_max_value = {
    torch.int16: (tl.int16, torch.iinfo(torch.int16).max),
    torch.int32: (tl.int32, torch.iinfo(torch.int32).max),
    torch.float16: (tl.float16, torch.finfo(torch.float16).max),
    torch.float32: (tl.float32, torch.finfo(torch.float32).max),
    torch.bfloat16: (tl.float32, torch.finfo(torch.float32).max),
}

21 

22 

@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
    dtype_max_value: tl.constexpr,
):
    """Stage 1 of the flat argmin: each program reduces one BLOCK_SIZE tile.

    Program `block_id` scans ``inp[block_id*BLOCK_SIZE : (block_id+1)*BLOCK_SIZE]``,
    writing the tile's minimum value to ``mid_value[block_id]`` and that
    minimum's flat position in ``inp`` to ``mid_index[block_id]``.
    Out-of-range lanes are filled with ``dtype_max_value`` so they never win
    the min.
    """
    block_id = tle.program_id(0)
    block_start = block_id * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    tile = tl.load(inp + offsets, mask=in_bounds, other=dtype_max_value)
    tile_min, local_pos = tl.min(tile, axis=0, return_indices=True)
    tl.store(mid_value + block_id, tile_min)
    # Translate the tile-local winner into a flat index over the whole input.
    tl.store(mid_index + block_id, local_pos + block_start)

44 

45 

@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
    dtype_max_value: tl.constexpr,
):
    """Stage 2 of the flat argmin: reduce the per-block minima to one index.

    A single program loads all ``mid_size`` partial minima (padding unused
    lanes with ``dtype_max_value``), finds which block held the global
    minimum, and copies that block's stored flat index into ``out``.
    """
    lane = tl.arange(0, BLOCK_MID)
    valid = lane < mid_size
    partial_mins = tl.load(mid_value + lane, mask=valid, other=dtype_max_value)
    winning_block = tl.argmin(partial_mins, axis=0)
    global_index = tl.load(mid_index + winning_block)
    tl.store(out, global_index)

64 

65 

def heur_block_n(args):
    """Heuristic BLOCK_N: next power of two covering N, capped at 4096."""
    pow2 = triton.next_power_of_2(args["N"])
    return pow2 if pow2 < 4096 else 4096

68 

69 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    tl_dtype: tl.constexpr,
    dtype_max_value: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmin along the middle axis of a contiguous (M, N, K) view of `inp`.

    Program (pid_m, pid_k) handles BLOCK_M rows of one k-slice; the minimum is
    taken over the N axis and the winning position along N is written to
    ``out_index[m, k]``.
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # NOTE(review): tl_dtype is not used after this promotion — the int16→int32
    # rewrite looks vestigial; confirm whether it can be dropped.
    if tl_dtype is tl.int16:
        tl_dtype = tl.int32
    n_offset = tl.arange(0, BLOCK_N)
    # Flat offsets into the contiguous (M, N, K) layout.
    offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
    offset_index = m_offset * K + pid_k
    # Row mask for the store; full 2-D mask for the load.
    mask1 = m_offset < M
    mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
    inp_ptrs = inp + offset
    # BUGFIX: masked-out lanes must be padded with the dtype's MAX value, not
    # -inf. Padding a min-reduction with -inf makes every out-of-range lane the
    # "minimum", so any N that is not a power of two (BLOCK_N > N) returned a
    # bogus padding index. dtype_max_value is passed in for exactly this purpose.
    inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_max_value)
    _, result_index = tl.min(inp_vals, axis=1, return_indices=True)

    tl.store(out_index + offset_index, result_index, mask=mask1)

106 

107 

def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Index of the minimum value, computed with Triton kernels.

    Args:
        inp: input tensor; its dtype must be a key of
            ``torch_dtype_to_tl_dtype_and_max_value``.
        dim: axis to reduce. ``None`` reduces over the flattened tensor.
        keepdim: keep the reduced axis (or, for ``dim=None``, every axis)
            with size 1.
        dtype: dtype of the intermediate per-block minima in the flat
            reduction; defaults to ``inp.dtype``.

    Returns:
        An int64 tensor of indices of the minimum values.
    """
    logger.debug("GEMS argmin")
    if dim is None:
        # BUGFIX: the two-stage flat reduction indexes the storage linearly,
        # so a non-contiguous input would yield an index in storage order
        # rather than logical (flattened) order. The dim-path already forces
        # contiguity; do the same here.
        inp = inp.contiguous()
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # Roughly sqrt(M)-sized blocks balance the two reduction stages.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        if keepdim:
            # Every axis survives, collapsed to length 1.
            out = torch.empty([1] * inp.dim(), dtype=torch.int64, device=inp.device)
        else:
            out = torch.empty([], dtype=torch.int64, device=inp.device)

        tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]
        with torch_device_fn.device(inp.device):
            argmin_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
                dtype_max_value,
            )
            argmin_kernel_2[(1, 1, 1)](
                mid_value,
                mid_index,
                out,
                mid_size,
                block_mid,
                dtype_max_value,
            )
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim
        # View the tensor as (M, N, K): leading axes, reduced axis, trailing axes.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        inp = inp.contiguous()

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]

        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )
        with torch_device_fn.device(inp.device):
            argmin_kernel[grid](
                inp,
                out_index,
                M,
                N,
                K,
                tl_dtype,
                dtype_max_value,
            )

        return out_index