Coverage for src/flag_gems/runtime/backend/_mthreads/ops/any.py: 0%

135 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2import math 

3from typing import Sequence 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Per-op logger named after this module's basename, nested under the
# flag_gems MUSA (_mthreads) backend logger namespace.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

16 

17 

18def _flatten_dim(shape: Sequence[int], dim: int): 

19 dim = dim % len(shape) 

20 n = shape[dim] 

21 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1 

22 outer = math.prod(shape[:dim]) if dim > 0 else 1 

23 return dim, n, inner, outer 

24 

25 

26# Favor smaller tiles to keep occupancy high on MUSA; wide tiles trigger register 

27# pressure and hurt latency for large reductions. 

28def _select_reduction_config(m_rows: int, n_cols: int): 

29 block_n = min(256, max(64, 1 << int(math.ceil(math.log2(n_cols))))) 

30 max_block_m = 1 << int(math.floor(math.log2(max(1, m_rows)))) 

31 block_m = min(32, max_block_m) 

32 num_warps = 8 if block_n >= 256 else 4 

33 return block_m, block_n, num_warps 

34 

35 

# Row-wise "any" over a contiguous (M, N) view: out[r] = any(inp[r, :] != 0).
@libentry()
@triton.jit
def any_kernel_dim(
    inp,  # pointer to input viewed as a contiguous (M, N) matrix
    out,  # pointer to M output booleans, one per row
    M,  # number of rows (non-reduced elements)
    N,  # reduction length
    BLOCK_M: tl.constexpr,  # rows handled per program
    BLOCK_N: tl.constexpr,  # columns loaded per inner iteration
):
    pid = tle.program_id(0)
    # int64 row indices so rows * N cannot overflow for large tensors.
    rows = (pid * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    row_mask = rows < M
    row_offsets = rows * N

    acc = tl.zeros((BLOCK_M,), dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        col_mask = cols < N
        # Short-circuit: rows whose accumulator is already True are masked
        # out, so their lanes load `other=0.0` instead of touching memory.
        active = acc == 0
        mask = row_mask[:, None] & col_mask[None, :] & active[:, None]
        vals = tl.load(inp + row_offsets[:, None] + cols[None, :], mask=mask, other=0.0)
        # Max over booleans == OR over the tile's columns.
        block_any = tl.max(vals != 0, axis=1).to(tl.int1)
        acc = acc | block_any
    tl.store(out + rows, acc, mask=row_mask)

61 

62 

# Strided variant of `any_kernel_dim`: the reduced dim's elements sit
# STRIDE_REDUCE apart instead of being packed into a contiguous row.
@libentry()  # added for consistency: every sibling kernel in this file
# goes through the libentry caching wrapper.
@triton.jit
def any_kernel_dim_strided(
    inp,  # base pointer of the input tensor
    out,  # pointer to M contiguous output booleans
    M,  # total number of reduced results (outer * inner)
    N,  # reduction length
    INNER,  # product of dims after the reduced dim
    STRIDE_OUTER,  # element stride between consecutive outer slices
    STRIDE_REDUCE,  # element stride along the reduced dim
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    # int64 arithmetic so pointer offsets cannot overflow on large inputs.
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Decompose the flat output index into (outer, inner) coordinates.
    # NOTE(review): inner elements are addressed with an implicit stride of
    # 1, so the dims after the reduced one must be contiguous — confirm at
    # the call sites.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    acc = tl.zeros((BLOCK_M,), dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        cols = cols.to(tl.int64)
        col_mask = cols < N
        # Short-circuit rows that already observed a nonzero value.
        active = acc == 0
        mask = row_mask[:, None] & col_mask[None, :] & active[:, None]
        vals = tl.load(
            base_ptr[:, None] + cols[None, :] * STRIDE_REDUCE, mask=mask, other=0.0
        )
        block_any = tl.max(vals != 0, axis=1).to(tl.int1)
        acc = acc | block_any
    tl.store(out + rows, acc, mask=row_mask)

97 

98 

# Stage 1 of the global "any": each program folds one BLOCK_SIZE chunk of
# the flattened input into a single boolean partial result.
@libentry()
@triton.jit
def any_kernel_1(
    inp,  # pointer to the flattened input
    mid,  # pointer to per-program boolean partials
    n_elements,  # total element count
    mid_size,  # unused in the body; kept for launch-signature symmetry
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < n_elements
    # Out-of-range lanes read 0.0, which is falsy and cannot flip the result.
    inp_val = tl.load(inp_ptrs, mask=mask, other=0.0)
    # Max over booleans == OR across the chunk.
    any_val = tl.max(inp_val != 0, axis=0)
    mid_ptr = mid + pid
    tl.store(mid_ptr, any_val)

116 

117 

# Stage 2: a single program ORs all stage-1 partials into one scalar.
# Expects BLOCK_MID >= MID_SIZE (the host passes next_power_of_2(mid_size))
# so one tl.arange covers the whole `mid` buffer.
@libentry()
@triton.jit
def any_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    # Masked lanes read 0 (False) and so never affect the max-reduction.
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(tl.int1)
    any_val = tl.max(mid_val, axis=0)
    tl.store(out, any_val)

127 

128 

def any(inp):
    """Global reduction: scalar bool tensor, True iff any element is nonzero.

    Two-stage scheme: stage 1 folds `tile`-sized chunks of the flattened
    input into a boolean `mid` buffer; stage 2 folds `mid` into the scalar
    output with a single program.  (Intentionally shadows builtin `any` —
    this is the op's registered name.)
    """
    logger.debug("GEMS_MTHREADS ANY")
    numel = inp.numel()

    # Chunk size ~ 2 * sqrt(numel), capped at 4096 and at the problem size,
    # which balances work between the two stages.
    tile = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    tile = min(tile * 2, 4096, triton.next_power_of_2(numel))
    n_mid = triton.cdiv(numel, tile)
    mid_tile = triton.next_power_of_2(n_mid)

    warps_stage1 = min(8, max(1, tile // 128))
    warps_stage2 = min(8, max(1, mid_tile // 128))

    mid = torch.empty((n_mid,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        any_kernel_1[(n_mid, 1)](
            inp,
            mid,
            numel,
            n_mid,
            tile,
            num_warps=warps_stage1,
            num_stages=2,
        )
        any_kernel_2[(1, 1)](
            mid,
            out,
            n_mid,
            mid_tile,
            num_warps=warps_stage2,
            num_stages=2,
        )

    return out

163 

164 

def triton_any_dim_strided(
    inp: torch.Tensor, dim: int, keepdim: bool = False
) -> torch.Tensor:
    """Logical-any reduction of `inp` along `dim` via the strided kernel.

    NOTE(review): the kernel addresses inner elements with an implicit
    stride of 1 and derives the outer stride as stride[dim] * n, so this
    path assumes the layout of a contiguous tensor — confirm before feeding
    it non-contiguous inputs.
    """
    dim = dim % inp.ndim
    out_shape = list(inp.shape)
    dim, reduce_n, inner, outer = _flatten_dim(out_shape, dim)
    num_rows = outer * inner

    reduce_stride = inp.stride()[dim]
    outer_stride = reduce_stride * reduce_n

    flat_out = torch.empty((num_rows,), dtype=torch.bool, device=inp.device)
    block_m, block_n, warps = _select_reduction_config(num_rows, reduce_n)
    grid = (triton.cdiv(num_rows, block_m),)
    with torch_device_fn.device(inp.device):
        any_kernel_dim_strided[grid](
            inp,
            flat_out,
            num_rows,
            reduce_n,
            inner,
            outer_stride,
            reduce_stride,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=warps,
            num_stages=2,
        )

    # Keep the reduced dim as size 1, then drop it unless keepdim was asked.
    out_shape[dim] = 1
    result = flat_out.view(out_shape)
    return result if keepdim else result.squeeze(dim=dim)

200 

201 

def any_dim(inp, dim=None, keepdim=False):
    """torch.any semantics for a single dim (or a full reduction when None)."""
    logger.debug("GEMS_MTHREADS ANY DIM")
    if dim is None:
        # Full reduction; re-expand to all-ones shape when keepdim is set.
        reduced = any(inp)
        return torch.reshape(reduced, [1] * inp.ndim) if keepdim else reduced
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    return triton_any_dim_strided(inp, dim, keepdim=keepdim)

212 

213 

def any_dims(inp, dim=None, keepdim=False):
    """torch.any semantics over a list of dims (or one dim / full reduce)."""
    logger.debug("GEMS_MTHREADS ANY DIMS")

    # Scalar / None dims delegate to the single-dim entry point.
    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)

    # Reduce one axis at a time, keeping dims so later indices stay valid;
    # squeeze them afterwards (in reverse) if keepdim is False.
    # NOTE(review): an empty dim list returns `inp` unchanged — verify this
    # matches torch.any's empty-dims behavior.
    reduce_dims = sorted({d % inp.ndim for d in dim})
    result = inp
    for d in reduce_dims:
        result = triton_any_dim_strided(result, d, keepdim=True)
    if not keepdim:
        for d in reversed(reduce_dims):
            result = result.squeeze(dim=d)
    return result

229 

230 

# Public API of this op module; `any` deliberately shadows the builtin.
__all__ = ["any", "any_dim", "any_dims"]