Coverage for src/flag_gems/runtime/backend/_mthreads/ops/all.py: 0%

140 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2import math 

3from typing import Sequence 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import dim_compress, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as tle 

13 

# Per-op logger in the mthreads backend namespace; the suffix is this
# module's own unqualified name (e.g. "all").
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[-1]
)

17 

# Autotune search space for the strided row-wise reduction kernel below:
# BLOCK_M = output rows per program, BLOCK_N = tile width along the reduced
# dimension. Keyed on ["M", "N"] at the @triton.autotune site.
NAIVE_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": 4, "BLOCK_N": 1024}, num_warps=4),
    triton.Config({"BLOCK_M": 8, "BLOCK_N": 1024}, num_warps=4),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=8),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=8),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4),
    triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4),
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 128}, num_warps=4),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4),
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8),
    triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=8),
]

36 

37 

@triton.jit
def reduce_all(a, b):
    # Combine function for tl.reduce: elementwise logical AND of two
    # partial reduction results.
    return a and b

41 

42 

@triton.autotune(configs=NAIVE_REDUCTION_CONFIGS, key=["M", "N"])
@triton.jit
def all_kernel_dim_strided(
    inp,
    out,
    M,  # number of flattened output rows (outer * inner)
    N,  # length of the reduced dimension
    INNER,  # product of the dims after the reduced one
    STRIDE_OUTER,  # element stride between consecutive outer indices
    STRIDE_REDUCE,  # element stride along the reduced dimension
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # AND-reduce along one (possibly strided) dimension:
    # out[r] = all(row r of the (M, N) logical view). Each program handles
    # BLOCK_M output rows and loops over N in BLOCK_N-wide tiles.
    pid = tl.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    # Promote to int64 so offset arithmetic cannot overflow on big tensors.
    rows = rows.to(tl.int64)
    row_mask = rows < M

    # Decompose the flat row id into (outer, inner) coordinates. Inner
    # elements are addressed with an implicit stride of 1 --
    # NOTE(review): that holds for contiguous layouts; verify for
    # arbitrarily strided inputs.
    outer_idx = rows // INNER
    inner_idx = rows % INNER
    base_ptr = inp + outer_idx * STRIDE_OUTER + inner_idx

    # Accumulator starts at the AND-neutral element (all lanes true).
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)
        cols = cols.to(tl.int64)
        col_mask = cols < N
        mask = row_mask[:, None] and col_mask[None, :]
        # other=1.0 keeps masked-off lanes neutral for the AND reduction.
        vals = tl.load(
            base_ptr[:, None] + cols[None, :] * STRIDE_REDUCE, mask, other=1.0
        )
        acc = acc and (vals != 0)
    all_val = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out + rows, all_val, mask=row_mask)

77 

78 

def _flatten_dim(shape: Sequence[int], dim: int):
    """Normalize *dim* and describe the reduction geometry of *shape*.

    Returns ``(dim, n, inner, outer)`` where ``n`` is the size of the
    reduced dimension, ``inner`` the product of the trailing dims and
    ``outer`` the product of the leading dims (both 1 when empty).
    """
    dim %= len(shape)
    reduce_len = shape[dim]
    # math.prod of an empty slice is 1, so no boundary checks are needed.
    inner = math.prod(shape[dim + 1:])
    outer = math.prod(shape[:dim])
    return dim, reduce_len, inner, outer

85 

86 

def triton_all_dim_strided(
    inp: torch.Tensor, dim: int, keepdim: bool = False
) -> torch.Tensor:
    """Logical-AND reduction of *inp* along a single dim, addressed by stride.

    Flattens the non-reduced dims into ``m = outer * inner`` output rows,
    launches the strided kernel, then reshapes back (squeezing *dim* unless
    *keepdim*). Returns a bool tensor.
    """
    dim = dim % inp.ndim
    shape = list(inp.shape)
    dim, n, inner, outer = _flatten_dim(shape, dim)
    m = outer * inner

    stride = inp.stride()
    stride_reduce = stride[dim]
    # NOTE(review): the outer stride is derived as stride[dim] * n and the
    # inner axis is assumed to have stride 1 inside the kernel -- this
    # matches contiguous row-major inputs; confirm for arbitrary layouts.
    stride_outer = stride_reduce * n

    out_flat = torch.empty((m,), dtype=torch.bool, device=inp.device)
    # One program per BLOCK_M output rows; BLOCK_M is chosen by autotune.
    grid = lambda meta: (triton.cdiv(m, meta["BLOCK_M"]),)
    all_kernel_dim_strided[grid](
        inp,
        out_flat,
        m,
        n,
        inner,
        stride_outer,
        stride_reduce,
    )

    # Restore the original rank with the reduced dim collapsed to 1.
    shape[dim] = 1
    out = out_flat.view(shape)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

116 

117 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,  # number of rows kept after reduction
    N,  # number of reduced elements per row
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise AND reduction over a dense (M, N) view: out[m] = all(inp[m, :]).
    # Rows are addressed as inp + rows * N, i.e. the reduced axis is assumed
    # to be innermost and contiguous (the caller compresses dims first).
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # Accumulator starts at the AND-neutral element (all lanes true).
    _all = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # other=1.0 keeps masked-off lanes neutral for the AND reduction.
        a = tl.load(inp + cols, mask, other=1.0)
        _all = _all and (a != 0)
    all = tl.reduce(_all, axis=1, combine_fn=reduce_all)
    tl.store(out, all[:, None], row_mask)

148 

149 

@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,  # per-program partial results, length mid_size
    n_elements,
    mid_size,  # unused in the kernel body; kept in the launch signature
    BLOCK_SIZE: tl.constexpr,
):
    # First pass of the full reduction: each program ANDs one BLOCK_SIZE
    # chunk of the flat input into mid[pid].
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < n_elements
    # other=1.0 is the AND-neutral element: out-of-range lanes read as true.
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0)
    all_val = tl.reduce(inp_val != 0, axis=0, combine_fn=reduce_all)
    mid_ptr = mid + pid
    tl.store(mid_ptr, all_val)

167 

168 

@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Second pass: a single program folds every per-chunk partial result
    # from pass 1 into one scalar bool. BLOCK_MID must cover MID_SIZE.
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    # other=1 keeps masked-off lanes neutral for the AND reduction.
    mid_val = tl.load(mid_ptrs, mask=mask, other=1).to(tl.int1)
    all_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_all)
    tl.store(out, all_val)

178 

179 

def all(inp):
    """Return a 0-d bool tensor: True iff every element of *inp* is nonzero.

    Two-pass reduction: all_kernel_1 collapses fixed-size chunks into a
    partial vector, then all_kernel_2 folds that vector into a scalar.
    (Intentionally shadows the ``all`` builtin -- this is the op entry point.)
    """
    logger.debug("GEMS_MTHREADS ALL")
    numel = inp.numel()
    # A chunk size near sqrt(numel) balances the two passes; clamp it to
    # 4096 and never exceed the power-of-2 rounding of the element count.
    sqrt_block = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    chunk = min(sqrt_block * 2, 4096, triton.next_power_of_2(numel))
    num_chunks = triton.cdiv(numel, chunk)
    final_block = triton.next_power_of_2(num_chunks)

    partial = torch.empty((num_chunks,), dtype=torch.bool, device=inp.device)
    result = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(num_chunks, 1)](inp, partial, numel, num_chunks, chunk)
        all_kernel_2[(1, 1)](partial, result, num_chunks, final_block)

    return result

196 

197 

def all_dim(inp, dim=None, keepdim=False):
    """Logical-AND reduction along *dim* (full reduction when dim is None).

    Returns a bool tensor; with *keepdim* the reduced dim is kept as size 1.
    Raises AssertionError for an out-of-range *dim*.
    """
    logger.debug("GEMS_MTHREADS ALL DIM")
    if dim is None:
        # Full reduction; `all` here is the module-level op, not the builtin.
        reduced = all(inp)
        return torch.reshape(reduced, [1] * inp.ndim) if keepdim else reduced

    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    with torch_device_fn.device(inp.device):
        return triton_all_dim_strided(inp, dim=dim % inp.ndim, keepdim=keepdim)

211 

212 

def all_dims(inp, dim=None, keepdim=False):
    """Logical-AND reduction over several dims at once.

    *dim* may be None / an int (delegated to ``all_dim``) or a sequence of
    dims. The reduced dims are compressed to the innermost positions, a
    dense (M, N) row reduction runs on device, and the reduced dims are
    squeezed away unless *keepdim*. Raises AssertionError on invalid dims.
    """
    logger.debug("GEMS_MTHREADS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    # BUG FIX: the old `assert (<genexpr> for i in dim)` asserted the
    # generator object itself, which is always truthy, so invalid dims were
    # never rejected. `builtins.all` is shadowed by the module-level `all`
    # op here, so validate each entry explicitly.
    for i in dim:
        assert -inp.ndim <= i < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # Move all reduced dims to be innermost so the kernel can address the
    # input as a dense (M, N) matrix.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    out = torch.empty(shape, dtype=torch.bool, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        all_kernel_dim[grid](inp, out, M, N)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

237 

238 

# Public op surface of this module; note `all` deliberately shadows the builtin.
__all__ = ["all", "all_dim", "all_dims"]