Coverage for src/flag_gems/runtime/backend/_mthreads/ops/argmin.py: 0%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import builtins 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_max 

13 

# Module-level logger keyed to this file's name so backend-specific argmin
# logging can be filtered under the flag_gems.runtime.backend._mthreads.ops
# namespace.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

17 

18# Favor wider column tiles for long rows and more rows per block for tall shapes. 

# Autotune search space for the per-dim argmin kernel, built from
# (BLOCK_M, BLOCK_N, num_warps, num_stages) tuples. Order is preserved from
# the original hand-written list; _prune_reduction_configs trims it by N.
ARGMIN_REDUCTION_CONFIGS = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n},
        num_warps=warps,
        num_stages=stages,
    )
    for block_m, block_n, warps, stages in (
        (32, 64, 4, 1),
        (32, 128, 8, 1),
        (64, 128, 8, 2),
        (32, 256, 8, 2),
        (16, 256, 4, 2),
        (16, 512, 8, 2),
        (64, 64, 4, 1),
    )
]

28 

29 

30def _prune_reduction_configs(configs, nargs, **meta): 

31 n = meta.get("N", nargs["N"]) 

32 if n <= 128: 

33 max_block_n = 128 

34 elif n <= 2048: 

35 max_block_n = 256 

36 else: 

37 max_block_n = 512 

38 return [cfg for cfg in configs if cfg.kwargs["BLOCK_N"] <= max_block_n] 

39 

40 

41def _flatten_dim(shape, dim): 

42 dim = dim % len(shape) 

43 n = shape[dim] 

44 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1 

45 outer = math.prod(shape[:dim]) if dim > 0 else 1 

46 return dim, n, inner, outer 

47 

48 

@libentry()
@triton.jit
def argmin_kernel_1(
    inp,  # pointer to the flattened input tensor
    mid_value,  # output: per-block minimum value, one slot per program
    mid_index,  # output: per-block argmin as a flat index into inp
    M,  # total number of elements in inp
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the flattened (dim=None) argmin: each program reduces one
    # BLOCK_SIZE chunk to a (min value, flat index) pair; stage 2
    # (argmin_kernel_2) then reduces across programs.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < M

    # Out-of-range lanes load the dtype's maximum so they can never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_val = tl.load(inp + offset, mask=mask, other=max_value, cache_modifier=".cg")
    # tie_break_left keeps the lowest block-local index on equal values.
    min_val, min_index = tl.min(
        inp_val, axis=0, return_indices=True, return_indices_tie_break_left=True
    )
    tl.store(mid_value + pid, min_val)
    # Convert the block-local winner to a flat index before storing.
    tl.store(mid_index + pid, min_index + pid * BLOCK_SIZE)

69 

70 

@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,  # per-block minima produced by argmin_kernel_1
    mid_index,  # per-block flat argmin indices produced by argmin_kernel_1
    out,  # output: scalar slot receiving the global flat argmin index
    mid_size,  # number of valid entries in mid_value / mid_index
    BLOCK_MID: tl.constexpr,  # power-of-2 >= mid_size (single-tile reduction)
):
    # Stage 2 of the flattened argmin: a single program reduces all stage-1
    # partials to the final flat index.
    offset = tl.arange(0, BLOCK_MID)
    mask = offset < mid_size
    # Padding lanes hold the dtype max so they cannot be selected.
    max_value = get_dtype_max(mid_value.type.element_ty)
    mid_val = tl.load(mid_value + offset, mask=mask, other=max_value)
    _, index_val = tl.min(
        mid_val,
        axis=0,
        return_indices=True,
        return_indices_tie_break_left=True,
    )
    # index_val selects the winning block; look up its stored flat index.
    out_val = tl.load(mid_index + index_val)
    tl.store(out, out_val)

92 

93 

@libentry()
@triton.autotune(
    configs=ARGMIN_REDUCTION_CONFIGS,
    key=["M", "N"],
    warmup=8,
    rep=40,
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
)
@triton.jit
def argmin_kernel(
    inp,  # pointer to the (contiguous) input tensor
    out_index,  # output: int32 argmin per row, length M
    M,  # number of reduction rows (= outer * inner)
    N,  # length of the reduced dimension
    INNER,  # product of dims after the reduced one (>= 1)
    STRIDE_OUTER,  # element stride between consecutive outer slices
    STRIDE_REDUCE,  # element stride along the reduced dimension
    BLOCK_M: tl.constexpr,  # rows handled per program
    BLOCK_N: tl.constexpr,  # columns reduced per loop iteration
):
    # Per-dim argmin: each program owns BLOCK_M rows and sweeps the reduced
    # axis in BLOCK_N tiles, carrying a running (min value, index) per row.
    pid_m = tle.program_id(0)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # int64 row arithmetic avoids overflow in the address computation below.
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Decompose the flat row id into (outer, inner) coordinates.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    dtype = inp.type.element_ty
    # bf16 compares accumulate in fp32 to avoid precision-induced ties.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    argmin_values = tl.full([BLOCK_M], dtype=tl.int32, value=0)

    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        n_offset = n_offset.to(tl.int64)
        mask = row_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = base_ptr[:, None] + n_offset[None, :] * STRIDE_REDUCE
        # Masked lanes load the dtype max so they never become the minimum.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value, cache_modifier=".cg")
        local_min, local_argmin = tl.min(
            inp_vals,
            1,
            return_indices=True,
            return_indices_tie_break_left=True,
        )
        local_argmin = local_argmin.to(tl.int32)
        # Strict < keeps the earliest tile's winner on cross-tile ties,
        # preserving the leftmost-index tie-break across the whole row.
        # NOTE(review): strict < is false for NaN, so NaN inputs are never
        # selected here — confirm this matches the intended torch semantics.
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        argmin_values = tl.where(
            update, (start_n + local_argmin).to(tl.int32), argmin_values
        )

    out_index_ptrs = out_index + rows
    tl.store(out_index_ptrs, argmin_values, mask=row_mask)

150 

151 

def _argmin_flat(inp, dtype, keepdim):
    """Global argmin over the flattened contiguous input (dim=None path).

    Two-stage reduction: argmin_kernel_1 collapses BLOCK_SIZE chunks to
    per-block (min, flat index) pairs, then argmin_kernel_2 collapses those
    pairs to the final flat index.

    Args:
        inp: contiguous input tensor.
        dtype: dtype for the stage-1 value buffer; defaults to inp.dtype.
        keepdim: if True, return an all-ones-shaped tensor instead of 0-dim.

    Returns:
        int64 tensor holding the flat index of the minimum.
    """
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    # Block size near 4*sqrt(M) balances work between the two stages; capped
    # at 4096 and never padded beyond the problem size itself.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    block_size = builtins.min(block_size * 4, 4096, triton.next_power_of_2(M))
    mid_size = triton.cdiv(M, block_size)
    # Stage 2 reduces all partials in one tile, so pad to a power of 2.
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
    # keepdim on the flattened reduction yields a rank-preserving all-ones
    # shape; otherwise a 0-dim scalar tensor.
    out_shape = [1] * inp.dim() if keepdim else []
    out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    # Roughly one warp per 128 lanes, clamped to [1, 8].
    num_warps_block = builtins.min(8, max(1, block_size // 128))
    num_warps_mid = builtins.min(8, max(1, block_mid // 128))

    with torch_device_fn.device(inp.device):
        argmin_kernel_1[(mid_size, 1, 1)](
            inp,
            mid_value,
            mid_index,
            M,
            block_size,
            num_warps=num_warps_block,
            num_stages=2,
        )
        argmin_kernel_2[(1, 1, 1)](
            mid_value,
            mid_index,
            out,
            mid_size,
            block_mid,
            num_warps=num_warps_mid,
            num_stages=2,
        )
    return out


def _argmin_along_dim(inp, dim, keepdim):
    """Argmin along one normalized (non-negative) dim of a contiguous tensor.

    Args:
        inp: contiguous input tensor.
        dim: reduction dimension, already normalized to [0, inp.ndim).
        keepdim: if True, keep the reduced dimension with size 1.

    Returns:
        int64 tensor of indices with the reduced dim kept or squeezed.
    """
    shape = list(inp.shape)
    dim, N, inner, outer = _flatten_dim(shape, dim)
    M = outer * inner
    stride_reduce = inp.stride()[dim]
    # Contiguous layout: consecutive outer slices are N reduce-strides apart.
    stride_outer = stride_reduce * N

    # Indices are computed in int32 on device and widened to int64 below to
    # match torch's argmin return dtype.
    out_index = torch.empty((M,), dtype=torch.int32, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        argmin_kernel[grid](
            inp,
            out_index,
            M,
            N,
            max(inner, 1),  # guard INNER so // and % in the kernel stay defined
            stride_outer,
            stride_reduce,
        )

    out_shape = shape.copy()
    out_shape[dim] = 1
    result = out_index.view(out_shape).to(torch.int64)
    if not keepdim:
        result = torch.squeeze(result, dim)
    return result


def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the minimum values of ``inp`` (torch.argmin-like).

    Args:
        inp: input tensor; made contiguous if it is not already.
        dim: dimension to reduce, or None to reduce over all elements.
        keepdim: keep the reduced dimension(s) with size 1.
        dtype: NOTE(review): only affects the stage-1 value buffer in the
            dim=None path — confirm this matches the intended API.

    Returns:
        int64 tensor of indices.

    Raises:
        AssertionError: if ``dim`` is outside [-inp.ndim, inp.ndim).
    """
    logger.debug("GEMS_MTHREADS ARGMIN")
    if not inp.is_contiguous():
        inp = inp.contiguous()

    if dim is None:
        return _argmin_flat(inp, dtype, keepdim)

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    return _argmin_along_dim(inp, dim % inp.ndim, keepdim)

230 

231 

# Public API of this module: only the argmin entry point is exported.
__all__ = ["argmin"]