Coverage for src/flag_gems/experimental_ops/amin.py: 0%

135 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1from functools import reduce 

2from operator import mul 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8 

# One Triton program per output row: program `pid` scans row `pid` of the
# [M, K] view in BLOCK_SIZE-sized chunks along the last axis and keeps a
# running minimum, then writes a single scalar to out_ptr[pid].
@triton.jit
def amin_reduce_last_kernel(
    x_ptr,  # input viewed as [M, K]
    out_ptr,  # output row vector of length M
    M,  # number of rows (outer size)
    K,  # reduction length (last-axis size)
    stride_xm,
    stride_xk,
    init,  # identity value for min (same dtype as x)
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    # Grid is launched as (M,), so this is normally always true; it still
    # guards the final store if the launch ever over-provisions programs.
    mask_m = pid < M
    acc = init
    k = 0
    while k < K:
        offs = k + tl.arange(0, BLOCK_SIZE)
        mask = mask_m & (offs < K)
        # Masked-out lanes load `init`, the identity for min, so they
        # cannot affect the block reduction below.
        vals = tl.load(
            x_ptr + pid * stride_xm + offs * stride_xk, mask=mask, other=init
        )
        block_min = tl.min(vals, axis=0)
        acc = tl.minimum(acc, block_min)
        k += BLOCK_SIZE
    tl.store(out_ptr + pid, acc, mask=mask_m)

34 

35 

36def _prod(seq): 

37 return int(reduce(mul, seq, 1)) 

38 

39 

40def _parse_dims(dim, ndim): 

41 if dim is None: 

42 return list(range(ndim)) 

43 if isinstance(dim, (list, tuple)): 

44 dims = [int(d) for d in dim] 

45 else: 

46 dims = [int(dim)] 

47 # normalize negatives and remove duplicates preserving order 

48 seen = set() 

49 norm = [] 

50 for d in dims: 

51 dd = d if d >= 0 else d + ndim 

52 if dd < 0 or dd >= ndim: 

53 raise IndexError("Dimension out of range in amin") 

54 if dd not in seen: 

55 norm.append(dd) 

56 seen.add(dd) 

57 return norm 

58 

59 

def _amin_impl(
    x: torch.Tensor, dim=None, keepdim: bool = False, out: torch.Tensor = None
):
    """Compute the minimum of *x* over *dim* with a Triton reduction kernel.

    The reduced axes are permuted to the end, the tensor is flattened to a
    contiguous [M, K] view, and one kernel program reduces each row of K
    elements. Supports float, bool, and integer dtypes (the min identity is
    +inf, True, or the dtype's max, respectively).

    Raises RuntimeError if *x* is not on CUDA, if any reduced axis has size
    0, or if *out* has the wrong number of elements.
    """
    if not x.is_cuda:
        raise RuntimeError("Triton amin kernel requires CUDA tensors")
    ndim = x.ndim
    reduce_dims = _parse_dims(dim, ndim)
    if len(reduce_dims) == 0:
        # No reduction dims specified, return input (or copy into out)
        if out is None:
            return x.clone()
        if out.numel() != x.numel():
            raise RuntimeError(
                "out tensor has incorrect number of elements for amin with empty dims"
            )
        out.copy_(x)
        return out

    # Determine output shape
    input_sizes = list(x.size())
    keep_sizes = input_sizes.copy()
    for d in reduce_dims:
        keep_sizes[d] = 1
    non_reduce_dims = [i for i in range(ndim) if i not in reduce_dims]
    non_reduce_sizes = [input_sizes[i] for i in non_reduce_dims]

    final_shape = keep_sizes if keepdim else non_reduce_sizes

    # Prepare permutation: move non-reduced dims first, reduced dims last
    perm = non_reduce_dims + reduce_dims
    x_perm = x.permute(perm)
    # contiguous() materializes a copy when the permutation changed layout,
    # which makes the [M, K] view below always legal.
    x_perm = x_perm.contiguous()

    # Flatten into [M, K]
    M = _prod(non_reduce_sizes) if len(non_reduce_sizes) > 0 else 1
    K = _prod([input_sizes[i] for i in reduce_dims]) if len(reduce_dims) > 0 else 1

    if K == 0:
        raise RuntimeError(
            "amin reduction has an empty dimension (no identity for min)"
        )

    x_2d = x_perm.view(M, K)

    # Identity/initial value for min based on dtype
    dt = x.dtype
    if dt.is_floating_point:
        init_val = float("inf")
    elif dt == torch.bool:
        # True is the min identity for bool (min(True, v) == v).
        init_val = True
    else:
        # integer types
        info = torch.iinfo(dt)
        init_val = int(info.max)

    # Prepare output row vector of length M
    if out is None:
        out_row = torch.empty((M,), dtype=x.dtype, device=x.device)
        out_target = None
    else:
        # Ensure out shape matches final_shape
        expected_numel = _prod(final_shape) if len(final_shape) > 0 else 1
        if out.numel() != expected_numel:
            raise RuntimeError("out tensor has incorrect number of elements")
        # We will write into a contiguous view; if out isn't contiguous, use a temp and then reshape/copy back
        if out.is_contiguous():
            # Kernel writes straight into the caller's storage.
            out_row = out.view(M)
            out_target = out
        else:
            out_row = torch.empty((M,), dtype=out.dtype, device=out.device)
            out_target = out

    # Strides for x_2d (contiguous row-major)
    stride_xm = x_2d.stride(0)
    stride_xk = x_2d.stride(1)

    # Launch kernel
    grid = lambda meta: (M,)
    BLOCK_SIZE = 1024
    amin_reduce_last_kernel[grid](
        x_2d,
        out_row,
        M,
        K,
        stride_xm,
        stride_xk,
        init_val,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Reshape to target final shape
    if len(final_shape) == 0:
        # Full reduction without keepdim: return a 0-d scalar tensor.
        result = out_row.view(())
    else:
        result = out_row.view(final_shape)

    if out_target is not None:
        # If original 'out' was non-contiguous, copy result into it respecting shape
        if not out_target.is_contiguous():
            # Copy into the provided 'out' tensor
            out_target.copy_(result)
            return out_target
        return out_target
    return result

164 

165 

def amin(*args, **kwargs):
    """Entry point mirroring aten.amin(Tensor, dim=None, keepdim=False).

    Accepts dim/keepdim positionally or as keywords and dispatches to
    _amin_impl. Raises RuntimeError if no tensor argument is given.
    """
    if len(args) == 0:
        raise RuntimeError("amin requires at least one tensor argument")
    x = args[0]
    dim = kwargs.get("dim", None)
    keepdim = kwargs.get("keepdim", False)

    # Positional handling: amin(x, dim), amin(x, dim, keepdim).
    # BUGFIX: bool is a subclass of int, so the bool check must run before
    # the int check; otherwise amin(x, True) is misparsed as dim=True and
    # the bool branch is unreachable.
    if len(args) >= 2:
        if isinstance(args[1], bool):
            keepdim = args[1]
        elif isinstance(args[1], (int, list, tuple)):
            dim = args[1]
    if len(args) >= 3:
        if isinstance(args[2], bool):
            keepdim = args[2]

    return _amin_impl(x, dim=dim, keepdim=keepdim, out=None)

185 

186 

def amin_out(*args, **kwargs):
    """Entry point mirroring aten.amin.out: amin_out(x, dim, keepdim, out).

    dim/keepdim/out may be positional or keyword; the trailing positional
    tensor (if any) is taken as 'out'. Raises RuntimeError when no tensor
    argument is given or no 'out' tensor can be found.
    """
    if len(args) == 0:
        raise RuntimeError("amin_out requires at least one tensor argument")
    x = args[0]

    # Extract keyword forms first; positional forms below may override.
    out = kwargs.get("out", None)
    dim = kwargs.get("dim", None)
    keepdim = kwargs.get("keepdim", False)

    # Positional arguments. BUGFIX: bool is a subclass of int, so the bool
    # check must precede the int check; otherwise amin_out(x, True, out)
    # would be misparsed as dim=True and the bool branch was unreachable.
    if len(args) >= 2:
        if isinstance(args[1], bool):
            keepdim = args[1]
        elif isinstance(args[1], (int, list, tuple)):
            dim = args[1]
        elif isinstance(args[1], torch.Tensor):
            out = args[1]
    if len(args) >= 3:
        if isinstance(args[2], bool):
            keepdim = args[2]
        elif isinstance(args[2], torch.Tensor) and out is None:
            out = args[2]
    if len(args) >= 4 and out is None and isinstance(args[3], torch.Tensor):
        out = args[3]

    if out is None:
        raise RuntimeError("amin_out requires an 'out' tensor argument")

    return _amin_impl(x, dim=dim, keepdim=keepdim, out=out)