Coverage for src/flag_gems/ops/cat.py: 66%

140 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2from typing import List, Tuple, Union 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

# Module-level logger; the op entry points below emit a debug line on each call.
logger = logging.getLogger(__name__)

9 

10 

@triton.jit
def cat_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    # Copy up to four inputs (slots a..d) into their slices of `out_ptr`.
    #
    # Grid axis 1 (pid_y) selects the input slot; grid axis 0 (pid_x) selects
    # a BLOCK_X-element chunk of that input's flat index range.
    #
    # Each input is addressed as a flat buffer viewed as
    # (pre, dim_size_in, dim_prod_post), and the output as
    # (pre, dim_size_out, dim_prod_post). Input element (p, d, q) is stored
    # at output position (p, d + dim_offset, q). Both buffers are therefore
    # assumed to be contiguous row-major.
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    # Pick the argument set of the slot this program works on.
    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_size_in = dim_size_in_a
        dim_offset = dim_offset_a
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_size_in = dim_size_in_b
        dim_offset = dim_offset_b
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_size_in = dim_size_in_c
        dim_offset = dim_offset_c
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_size_in = dim_size_in_d
        dim_offset = dim_offset_d
        total_elements = total_elements_d

    # Promote to int64 before any arithmetic. program_id is int32, and the
    # output offset below can exceed 2**31 - 1 once the concatenated result
    # is large, even when each individual input still fits in int32.
    block_start = pid_x.to(tl.int64) * BLOCK_X
    idx = block_start + tl.arange(0, BLOCK_X)
    mask = idx < total_elements

    # Decompose the flat input index into (pre, dim, post) coordinates.
    pre_idx = idx // (dim_size_in * dim_prod_post)
    dim_idx = (idx // dim_prod_post) % dim_size_in
    post_idx = idx % dim_prod_post

    # Re-linearize in the output, shifting along the concat dim by this
    # input's starting offset.
    out_idx = (
        pre_idx * dim_size_out * dim_prod_post
        + (dim_idx + dim_offset) * dim_prod_post
        + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)

76 

77 

def _cat_run_kernel(
    A: List[torch.Tensor],
    dim: int,
    out_shape: List[int],
    out: torch.Tensor,
):
    """Launch ``cat_copy_func_kernel_4`` to copy every tensor of ``A`` into ``out``.

    The kernel has four input slots, so inputs are processed up to four per
    launch. Grid axis 1 picks the slot; grid axis 0 tiles that input's
    elements in ``BLOCK``-sized chunks.

    Args:
        A: tensors to concatenate, in order. Presumably already validated and
            cast to ``out.dtype`` by the caller; this function does not check.
        dim: non-negative dimension to concatenate along.
        out_shape: shape of ``out``. Only ``out_shape[dim]`` is read here.
        out: destination tensor. The kernel writes flat offsets that assume a
            contiguous row-major layout of ``out_shape``.
    """
    BLOCK = 1024
    SLOTS = 4  # number of input pointer arguments of cat_copy_func_kernel_4

    # Layout terms shared by every launch; compute them once.
    dim_size_out = out_shape[dim]
    dim_prod_post = 1
    for size in A[0].shape[dim + 1 :]:
        dim_prod_post *= size

    # One job per input that actually has elements:
    # (tensor, size along dim, start offset along dim in out, numel).
    # An empty input still advances the offset, but launching for it would
    # copy nothing, so it is left out.
    jobs = []
    dim_offset = 0
    for tensor in A:
        size_along_dim = tensor.shape[dim]
        numel = tensor.numel()
        if numel > 0:
            jobs.append((tensor, size_along_dim, dim_offset, numel))
        dim_offset += size_along_dim

    for start in range(0, len(jobs), SLOTS):
        batch = jobs[start : start + SLOTS]
        pad = SLOTS - len(batch)

        # The kernel reads inputs as flat buffers, hence .contiguous(). It is
        # applied per launch rather than to every input up front.
        tensors = [job[0].contiguous() for job in batch]
        # Unused slots get a valid pointer and zero sizes. They never run,
        # because grid axis 1 only spans len(batch).
        tensors += [tensors[0]] * pad
        dim_sizes = [job[1] for job in batch] + [0] * pad
        dim_offsets = [job[2] for job in batch] + [0] * pad
        numels = [job[3] for job in batch] + [0] * pad

        grid = (triton.cdiv(max(numels), BLOCK), len(batch))
        cat_copy_func_kernel_4[grid](
            out,
            *tensors,
            *dim_sizes,
            dim_size_out,
            dim_prod_post,
            *dim_offsets,
            *numels,
            BLOCK_X=BLOCK,
        )

162 

163 

164def _cat_build_working_list( 

165 A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int 

166): 

167 """Returns (mode, payload) where mode is 'single'|'empty'|'multi'.""" 

168 if len(A) == 0: 

169 raise RuntimeError("torch.cat(): expected a non-empty list of Tensors") 

170 if len(A) == 1: 

171 return "single", A[0] 

172 

173 device = A[0].device 

174 dtype = A[0].dtype 

175 A = list(A) 

176 for i in range(len(A) - 1, -1, -1): 

177 if A[i].shape == torch.Size([0]): 

178 A.pop(i) 

179 if len(A) == 0: 

180 return "empty", torch.tensor([], device=device, dtype=dtype) 

181 if len(A) == 1: 

182 return "single", A[0] 

183 

184 assert dim >= -A[0].ndim and dim < A[0].ndim, f"Invalid dim: {dim}" 

185 dim %= A[0].ndim 

186 

187 inp_shapes = [list(_.shape) for _ in A] 

188 inp0_shape = inp_shapes[0] 

189 for s in inp_shapes[1:]: 

190 if len(s) != len(inp0_shape): 

191 raise RuntimeError( 

192 f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}" 

193 ) 

194 for tensor_idx, inp_shape in enumerate(inp_shapes): 

195 for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)): 

196 if idx != dim and length != common_length: 

197 raise RuntimeError( 

198 f"Sizes of tensors must match except in dimension {dim}. " 

199 f"Expected size {common_length} but got size {length} for tensor number " 

200 f"{tensor_idx} in the list" 

201 ) 

202 

203 dtypes = [t.dtype for t in A] 

204 dtype = dtypes[0] 

205 for dt in dtypes[1:]: 

206 dtype = torch.promote_types(dtype, dt) 

207 A = [t.to(dtype) if t.dtype != dtype else t for t in A] 

208 

209 shapes = [t.shape for t in A] 

210 cat_dim_sizes = [s[dim] for s in shapes] 

211 out_shape = list(shapes[0]) 

212 out_shape[dim] = sum(cat_dim_sizes) 

213 return "multi", (A, dim, out_shape, dtype, A[0].device) 

214 

215 

def cat_out(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int = 0,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """Concatenate ``A`` along ``dim`` into ``out`` and return ``out``.

    ``out`` is resized to the result shape if it differs. In the multi-input
    path the Triton kernel writes flat contiguous offsets in the promoted
    input dtype. When ``out`` is non-contiguous, has a different (castable)
    dtype, or lives on another device, the result is built in a temporary
    and copied into ``out``, rather than written through mismatched strides.

    Raises:
        RuntimeError: the promoted input dtype cannot be cast to ``out.dtype``
            (multi-input path), or input validation failed in
            ``_cat_build_working_list``.
        IndexError: ``dim`` is out of range (raised by input validation).
    """
    logger.debug("GEMS CAT_OUT")
    mode, payload = _cat_build_working_list(A, dim)
    if mode in ("single", "empty"):
        # Nothing to concatenate. copy_ converts to out.dtype itself.
        out.resize_(payload.shape)
        out.copy_(payload)
        return out

    A, dim, out_shape, dtype, device = payload
    if out.dtype != dtype and not torch.can_cast(dtype, out.dtype):
        raise RuntimeError(f"cat.out: expected out dtype {dtype}, got {out.dtype}")
    if list(out.shape) != out_shape:
        out.resize_(out_shape)

    # Check direct-write safety after the resize, since resize_ may re-stride.
    writable_directly = (
        out.dtype == dtype and out.device == device and out.is_contiguous()
    )
    if writable_directly:
        _cat_run_kernel(A, dim, out_shape, out)
    else:
        tmp = torch.empty(out_shape, dtype=dtype, device=device)
        _cat_run_kernel(A, dim, out_shape, tmp)
        out.copy_(tmp)
    return out

245 

246 

def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate ``A`` along ``dim`` into a newly allocated tensor.

    Mirrors ``torch.cat``: the result is always a new tensor, even when only
    one input contributes. Validation, dtype promotion and the single/empty
    shortcuts are delegated to ``_cat_build_working_list``.

    Raises:
        RuntimeError: input validation failed (empty list, rank or size
            mismatch).
        IndexError: ``dim`` is out of range.
    """
    logger.debug("GEMS CAT")
    mode, payload = _cat_build_working_list(A, dim)
    if mode == "single":
        # The payload can be the caller's own tensor. Returning it directly
        # would alias the input, so in-place ops on the result would mutate it.
        return payload.clone()
    if mode == "empty":
        # A freshly constructed empty tensor; no aliasing concern.
        return payload

    A, dim, out_shape, dtype, device = payload
    out = torch.empty(out_shape, dtype=dtype, device=device)
    _cat_run_kernel(A, dim, out_shape, out)
    return out