Coverage for src/flag_gems/ops/cat.py: 68%

109 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2from typing import List, Tuple, Union 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

@triton.jit
def cat_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    """Copy up to four input tensors into their slices of the cat output.

    Launched on a 2-D grid: axis 0 tiles the elements of one input in
    chunks of BLOCK_X; axis 1 (0..3) selects which of the four input
    slots this program copies. Each input `k` is described by its size
    along the cat dimension (dim_size_in_k), its starting offset along
    that dimension in the output (dim_offset_k), and its element count
    (total_elements_k). Unused slots are launched with grid-y smaller
    than 4 by the caller, or masked out via total_elements == 0.

    NOTE(review): indexing is pure linear-offset arithmetic, so this
    assumes both the inputs and the output are contiguous — the caller
    appears to guarantee this via .contiguous()/torch.empty.
    """
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    # Select the per-input parameters for this program's slot.
    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_size_in = dim_size_in_a
        dim_offset = dim_offset_a
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_size_in = dim_size_in_b
        dim_offset = dim_offset_b
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_size_in = dim_size_in_c
        dim_offset = dim_offset_c
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_size_in = dim_size_in_d
        dim_offset = dim_offset_d
        total_elements = total_elements_d

    block_start = pid_x * BLOCK_X
    offsets = tl.arange(0, BLOCK_X)
    # Mask off lanes past the end of this input (also disables the whole
    # block when total_elements == 0).
    mask = block_start + offsets < total_elements

    idx = block_start + offsets

    # Decompose the linear input index into (pre, dim, post) coordinates
    # of the contiguous layout: idx = (pre * dim_size_in + dim) * post + post_idx.
    pre_idx = idx // (dim_size_in * dim_prod_post)
    dim_idx = (idx // dim_prod_post) % dim_size_in
    post_idx = idx % dim_prod_post

    # Recompose the linear output index: same coordinates, but the cat
    # dimension is widened to dim_size_out and shifted by dim_offset.
    out_idx = (
        pre_idx * dim_size_out * dim_prod_post
        + (dim_idx + dim_offset) * dim_prod_post
        + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)

76 

77 

def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate a sequence of tensors along dimension ``dim``.

    Mirrors ``torch.cat`` semantics for the cases handled here: all
    tensors must share rank and every non-``dim`` extent; legacy 1-D
    empty tensors (shape ``torch.Size([0])``) are silently dropped;
    dtypes are promoted to a common dtype before copying. Inputs are
    copied into the output in batches of up to four by a Triton kernel.

    Args:
        A: non-empty tuple/list of tensors to concatenate.
        dim: dimension along which to concatenate (may be negative).

    Returns:
        A new tensor containing the concatenation (or ``A[0]`` itself
        when only one tensor remains after filtering).

    Raises:
        RuntimeError: on an empty input list or mismatched shapes/ranks.
        AssertionError: when ``dim`` is out of range for ``A[0]``.
    """
    logger.debug("GEMS CAT")
    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
    if len(A) == 1:
        return A[0]

    # torch.cat historically accepts 1-D empty tensors (torch.Size([0]))
    # mixed with tensors of any shape; drop them up front.
    device = A[0].device
    dtype = A[0].dtype
    A = [t for t in A if t.shape != torch.Size([0])]
    if len(A) == 0:
        return torch.tensor([], device=device, dtype=dtype)
    elif len(A) == 1:
        return A[0]

    assert dim >= -A[0].ndim and dim < A[0].ndim, f"Invalid dim: {dim}"
    dim %= A[0].ndim

    # Same rank check.
    inp_shapes = [list(_.shape) for _ in A]
    inp0_shape = inp_shapes[0]
    for s in inp_shapes[1:]:
        if len(s) != len(inp0_shape):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
            )
    # Every extent except the cat dimension must match tensor 0.
    for tensor_idx, inp_shape in enumerate(inp_shapes):
        for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
            if idx != dim and length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    # Type promotion: find the common dtype for all tensors.
    dtypes = [t.dtype for t in A]
    dtype = dtypes[0]
    for dt in dtypes[1:]:
        dtype = torch.promote_types(dtype, dt)
    # Convert only the tensors that need it.
    A = [t.to(dtype) if t.dtype != dtype else t for t in A]

    shapes = [t.shape for t in A]
    out_shape = list(shapes[0])
    out_shape[dim] = sum(s[dim] for s in shapes)
    out = torch.empty(out_shape, dtype=dtype, device=device)

    BLOCK = 1024

    # Loop-invariant geometry, hoisted out of the batch loop: the output
    # extent along `dim` and the product of all trailing extents.
    dim_size_out = out_shape[dim]
    dim_prod_post = 1
    for d in range(dim + 1, A[0].ndim):
        dim_prod_post *= A[0].shape[d]

    dim_offset = 0
    i = 0
    while i < len(A):
        batch = A[i : i + 4]
        num_in_batch = len(batch)

        # Per-slot kernel arguments: (tensor, dim_size_in, dim_offset, numel).
        args = []
        current_dim_offset = dim_offset
        for j in range(4):
            if j < num_in_batch:
                tensor = batch[j].contiguous()
                total_elements = tensor.numel()
                dim_size_in = tensor.shape[dim]
                args.extend([tensor, dim_size_in, current_dim_offset, total_elements])
                current_dim_offset += dim_size_in
            else:
                # Placeholder for an unused slot; never read because
                # grid-y stops at num_in_batch.
                args.extend([batch[0], 0, 0, 0])

        tensors = args[0::4]
        dim_sizes = args[1::4]
        dim_offsets = args[2::4]
        totals = args[3::4]

        max_elements_in_batch = max(totals)
        # Skip the launch when every tensor in the batch is empty (e.g.
        # shape (0, n)); a zero-sized grid is invalid and there is
        # nothing to copy anyway.
        if max_elements_in_batch > 0:
            grid = (triton.cdiv(max_elements_in_batch, BLOCK), num_in_batch)
            cat_copy_func_kernel_4[grid](
                out,
                *tensors,
                *dim_sizes,
                dim_size_out,
                dim_prod_post,
                *dim_offsets,
                *totals,
                BLOCK_X=BLOCK,
            )

        dim_offset = current_dim_offset
        i += num_in_batch

    return out