Coverage for src/flag_gems/ops/cat.py: 68%
109 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
2from typing import List, Tuple, Union
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
@triton.jit
def cat_copy_func_kernel_4(
    out_ptr,  # pointer to the concatenated output tensor
    in_ptr_a,  # pointers to up to four input tensors (unused slots get a dummy ptr)
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,  # size of the concat dimension for each input (0 for unused slots)
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,  # size of the concat dimension in the output
    dim_prod_post,  # product of all dimension sizes after the concat dimension
    dim_offset_a,  # each input's starting offset along the concat dim in the output
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,  # element count of each input (0 for unused slots)
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,  # elements copied per program along axis 0
):
    # Copy up to four inputs into their slices of a concatenation output.
    # Grid axis 0 tiles the flat element range of one input; grid axis 1
    # selects which of the four input slots this program handles.  The host
    # launches grid_y == number of real inputs in the batch, so unused
    # placeholder slots (dim_size_in == 0) are never executed.
    # NOTE(review): the flat pre/dim/post index decomposition below assumes
    # contiguous row-major inputs and output — the host side calls
    # .contiguous() on every input before launch.
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    # Select this program's input descriptor by slot index.
    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_size_in = dim_size_in_a
        dim_offset = dim_offset_a
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_size_in = dim_size_in_b
        dim_offset = dim_offset_b
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_size_in = dim_size_in_c
        dim_offset = dim_offset_c
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_size_in = dim_size_in_d
        dim_offset = dim_offset_d
        total_elements = total_elements_d

    block_start = pid_x * BLOCK_X
    offsets = tl.arange(0, BLOCK_X)
    # Mask off the tail of the last block of this input.
    mask = block_start + offsets < total_elements

    idx = block_start + offsets

    # Decompose the flat input index into (pre, dim, post) coordinates
    # relative to the concat dimension.
    pre_idx = idx // (dim_size_in * dim_prod_post)
    dim_idx = (idx // dim_prod_post) % dim_size_in
    post_idx = idx % dim_prod_post

    # Recompose the flat output index: same pre/post coordinates, but the
    # concat-dim coordinate is shifted by this input's offset, and the
    # concat-dim stride uses the *output* dimension size.
    out_idx = (
        pre_idx * dim_size_out * dim_prod_post
        + (dim_idx + dim_offset) * dim_prod_post
        + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)
def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate a sequence of tensors along dimension ``dim``.

    Mirrors ``torch.cat`` semantics: legacy 1-D empty tensors
    (``shape == torch.Size([0])``) are ignored, dtypes are promoted to a
    common type, and all non-concat dimensions must match.  The copy is
    performed by ``cat_copy_func_kernel_4``, processing inputs in batches
    of up to four per kernel launch.

    Args:
        A: non-empty sequence of tensors to concatenate.
        dim: dimension along which to concatenate; may be negative.

    Returns:
        A new tensor containing the concatenation (or the single input
        tensor itself when only one non-empty tensor remains).

    Raises:
        RuntimeError: empty input list, rank mismatch, or shape mismatch
            outside the concat dimension.
        IndexError: ``dim`` outside ``[-ndim, ndim)``.
    """
    logger.debug("GEMS CAT")
    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
    if len(A) == 1:
        return A[0]

    # torch.cat ignores legacy 1-D empty tensors (shape == torch.Size([0])).
    device = A[0].device
    dtype = A[0].dtype
    A = [t for t in A if t.shape != torch.Size([0])]
    if not A:
        return torch.tensor([], device=device, dtype=dtype)
    if len(A) == 1:
        return A[0]

    ndim = A[0].ndim
    # Explicit raise instead of `assert`: asserts vanish under `python -O`,
    # and torch raises IndexError for an out-of-range dim.
    if not (-ndim <= dim < ndim):
        raise IndexError(f"Invalid dim: {dim}")
    dim %= ndim

    # All inputs must share the same rank.
    inp_shapes = [list(t.shape) for t in A]
    inp0_shape = inp_shapes[0]
    for s in inp_shapes[1:]:
        if len(s) != len(inp0_shape):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
            )
    # All dimensions except `dim` must match the first tensor's.
    for tensor_idx, inp_shape in enumerate(inp_shapes):
        for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
            if idx != dim and length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    # Type promotion: find the common dtype for all tensors and convert
    # only the tensors that need it.
    device = A[0].device
    dtype = A[0].dtype
    for t in A[1:]:
        dtype = torch.promote_types(dtype, t.dtype)
    A = [t if t.dtype == dtype else t.to(dtype) for t in A]

    shapes = [t.shape for t in A]
    out_shape = list(shapes[0])
    out_shape[dim] = sum(s[dim] for s in shapes)
    out = torch.empty(out_shape, dtype=dtype, device=device)

    BLOCK = 1024
    # Loop invariants, hoisted out of the per-batch loop below.
    dim_size_out = out_shape[dim]
    dim_prod_post = 1  # product of all dimension sizes after `dim`
    for d in range(dim + 1, ndim):
        dim_prod_post *= A[0].shape[d]

    # Launch the 4-way copy kernel once per batch of up to four inputs.
    dim_offset = 0
    for i in range(0, len(A), 4):
        tensors_in_batch = A[i : i + 4]
        num_tensors_in_batch = len(tensors_in_batch)

        args = []
        total_elements_list = []
        for j in range(4):
            if j < num_tensors_in_batch:
                tensor = tensors_in_batch[j].contiguous()
                total_elements = tensor.numel()
                dim_size_in = tensor.shape[dim]

                args.extend([tensor, dim_size_in, dim_offset, total_elements])
                total_elements_list.append(total_elements)
                dim_offset += dim_size_in
            else:
                # Placeholders for unused slots; grid_y excludes them, so the
                # kernel never reads these descriptors.
                args.extend([tensors_in_batch[0], 0, 0, 0])
                total_elements_list.append(0)

        # Grid: axis 0 covers the largest input in the batch; axis 1 selects
        # the input slot (only real inputs are launched).
        max_elements_in_batch = max(total_elements_list)
        grid = (triton.cdiv(max_elements_in_batch, BLOCK), num_tensors_in_batch)

        (
            tensor_a,
            dim_size_in_a,
            dim_offset_a,
            total_elements_a,
            tensor_b,
            dim_size_in_b,
            dim_offset_b,
            total_elements_b,
            tensor_c,
            dim_size_in_c,
            dim_offset_c,
            total_elements_c,
            tensor_d,
            dim_size_in_d,
            dim_offset_d,
            total_elements_d,
        ) = args

        cat_copy_func_kernel_4[grid](
            out,
            tensor_a,
            tensor_b,
            tensor_c,
            tensor_d,
            dim_size_in_a,
            dim_size_in_b,
            dim_size_in_c,
            dim_size_in_d,
            dim_size_out,
            dim_prod_post,
            dim_offset_a,
            dim_offset_b,
            dim_offset_c,
            dim_offset_d,
            total_elements_a,
            total_elements_b,
            total_elements_c,
            total_elements_d,
            BLOCK_X=BLOCK,
        )

    return out