Coverage for src/flag_gems/runtime/backend/_mthreads/ops/min.py: 0%

123 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import builtins 

2import logging 

3import math 

4from collections import namedtuple 

5 

6import torch 

7import triton 

8import triton.language as tl 

9 

10from flag_gems.ops import min_dim as base_min_dim 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry 

13from flag_gems.utils import triton_lang_extension as tle 

14from flag_gems.utils.limits import get_dtype_max 

15 

# Per-op logger named after this module's leaf name ("min"), nested under
# the flag_gems mthreads backend logger hierarchy.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Return container for min_dim, mirroring torch.return_types.min
# (fields: values, indices) so callers can unpack either way.
MinOut = namedtuple("min", ["values", "indices"])

# Autotune search space for the row-wise min/argmin kernel.
# Expanded coverage favors smaller column tiles and more warps for tall shapes.
NAIVE_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 32}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=2, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=16, num_stages=2),
]

32 

33 

34def _prune_reduction_configs(configs, nargs, **meta): 

35 n = meta.get("N", None) 

36 if n is None: 

37 n = nargs["N"] 

38 max_block_n = 64 if n <= 128 else 256 

39 return [cfg for cfg in configs if cfg.kwargs["BLOCK_N"] <= max_block_n] 

40 

41 

42def _flatten_dim(shape, dim): 

43 dim = dim % len(shape) 

44 n = shape[dim] 

45 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1 

46 outer = math.prod(shape[:dim]) if dim > 0 else 1 

47 return dim, n, inner, outer 

48 

49 

@libentry()
@triton.jit
def min_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the flat (all-elements) min reduction: each program reduces
    # one BLOCK_SIZE-wide slice of the M-element input and writes a single
    # partial minimum to mid[pid].
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Out-of-range lanes load the dtype's maximum so they can never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_val = tl.load(inp_ptrs, mask=mask, other=max_value)
    min_val = tl.min(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, min_val)

67 

68 

@libentry()
@triton.jit
def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the flat min reduction: a single program folds the
    # mid_size partial minima from stage 1 into one scalar result.
    # BLOCK_MID is the next power of two >= mid_size, so one tile covers all.
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # Padding lanes beyond mid_size read dtype-max and never win the min.
    max_value = get_dtype_max(mid.type.element_ty)
    mid_val = tl.load(mid_ptrs, mask=mask, other=max_value)
    min_val = tl.min(mid_val)
    tl.store(out, min_val)

79 

80 

@libentry()
@triton.autotune(
    configs=NAIVE_REDUCTION_CONFIGS,
    key=["M", "N"],
    warmup=8,
    rep=40,
    prune_configs_by={"early_config_prune": _prune_reduction_configs},
)
@triton.jit
def min_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise min + argmin over the reduced dimension.  A "row" is one
    # (outer, inner) coordinate pair of the M = outer * inner flattened
    # output; the N-long reduced dimension is scanned in BLOCK_N chunks.
    pid_m = tle.program_id(0)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # 64-bit row indices so pointer arithmetic cannot overflow on large inputs.
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Decompose the flat row id: inner varies fastest, matching the
    # contiguous layout of the input with the reduced dim removed.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    dtype = inp.type.element_ty
    # Accumulate bf16 comparisons in fp32; other dtypes compare natively.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    # Running minima start at dtype-max so any real element replaces them.
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    # NOTE(review): indices accumulate in int32 — assumes N < 2**31; confirm upstream.
    argmin_values = tl.full([BLOCK_M], dtype=tl.int32, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        n_offset = n_offset.to(tl.int64)
        mask = row_mask[:, None] & (n_offset[None, :] < N)
        inp_ptrs = base_ptr[:, None] + n_offset[None, :] * STRIDE_REDUCE
        # Masked lanes read dtype-max so padding never wins; .cg hints a
        # streaming (non-temporal-ish) load for this one-pass scan.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value, cache_modifier=".cg")
        local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True)
        local_argmin = local_argmin.to(tl.int32)
        # Strict < keeps the earliest chunk's index on cross-chunk ties.
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        argmin_values = tl.where(
            update, (start_n + local_argmin).to(tl.int32), argmin_values
        )

    # One value/index pair per row, stored at the flat row offset.
    out_value_ptrs = out_value + rows
    out_index_ptrs = out_index + rows
    tl.store(out_value_ptrs, min_values, mask=row_mask)
    tl.store(out_index_ptrs, argmin_values, mask=row_mask)

134 

135 

def min(inp):
    """Global minimum of *inp*, via a two-stage block reduction.

    Stage 1 (``min_kernel_1``) writes one partial minimum per block into a
    scratch buffer; stage 2 (``min_kernel_2``) folds that buffer into a
    0-d tensor of the input's dtype, which is returned.
    """
    logger.debug("GEMS_MTHREADS MIN")
    numel = inp.numel()
    # Block size ~ 4 * sqrt(numel), rounded up to a power of two, capped at
    # 4096 and at the whole (power-of-two padded) input for tiny tensors.
    blk = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    blk = builtins.min(blk * 4, 4096, triton.next_power_of_2(numel))
    partials = triton.cdiv(numel, blk)
    blk_mid = triton.next_power_of_2(partials)

    scratch = torch.empty((partials,), dtype=inp.dtype, device=inp.device)
    result = torch.empty([], dtype=inp.dtype, device=inp.device)

    # Roughly one warp per 128 lanes, clamped to [1, 8].
    warps_stage1 = builtins.min(8, max(1, blk // 128))
    warps_stage2 = builtins.min(8, max(1, blk_mid // 128))

    with torch_device_fn.device(inp.device):
        min_kernel_1[(partials, 1, 1)](
            inp, scratch, numel, blk, num_warps=warps_stage1, num_stages=2
        )
        min_kernel_2[(1, 1, 1)](
            scratch, result, partials, blk_mid, num_warps=warps_stage2, num_stages=2
        )
    return result

159 

160 

def min_dim(inp, dim=None, keepdim=False):
    """Minimum values and indices of *inp* along *dim*.

    Mirrors ``torch.min(input, dim, keepdim)``: returns a namedtuple of
    (values, indices) with int64 indices.  Contiguous inputs use the fused
    min/argmin kernel; any other layout falls back to the generic
    implementation.
    """
    logger.debug("GEMS_MTHREADS MIN DIM")
    assert dim is not None, "dim must be specified"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim

    if not inp.is_contiguous():
        # Fall back to the generic implementation (handles arbitrary strides).
        return base_min_dim(inp, dim=dim, keepdim=keepdim)

    dims = list(inp.shape)
    dim, reduce_len, inner, outer = _flatten_dim(dims, dim)
    rows = outer * inner
    reduce_stride = inp.stride()[dim]
    outer_stride = reduce_stride * reduce_len

    # Flat per-row outputs; reshaped to the reduced shape afterwards.
    values = torch.empty((rows,), dtype=inp.dtype, device=inp.device)
    indices = torch.empty((rows,), dtype=torch.int32, device=inp.device)

    def grid(meta):
        return (triton.cdiv(rows, meta["BLOCK_M"]),)

    with torch_device_fn.device(inp.device):
        min_kernel[grid](
            inp,
            values,
            indices,
            rows,
            reduce_len,
            max(inner, 1),
            outer_stride,
            reduce_stride,
        )

    result_shape = dims.copy()
    result_shape[dim] = 1
    values = values.view(result_shape)
    indices = indices.view(result_shape).to(torch.int64)
    if not keepdim:
        values = torch.squeeze(values, dim)
        indices = torch.squeeze(indices, dim)
    return MinOut(values=values, indices=indices)

202 

203 

# Public API of this backend module: the flat and dim-wise min ops.
__all__ = ["min", "min_dim"]