Coverage for src/flag_gems/ops/hstack.py: 68% (106 statements)
import logging
from typing import List, Tuple, Union

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)
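

# hstack_copy_func_kernel_4 scatters up to four contiguous inputs into the output
# in a single launch: grid axis 1 selects which input a program copies, and grid
# axis 0 tiles that input's elements in BLOCK_X-sized chunks.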
@triton.jit
def hstack_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    # Grid axis 1 selects which of the four input slots this program copies.
    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_size_in = dim_size_in_a
        dim_offset = dim_offset_a
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_size_in = dim_size_in_b
        dim_offset = dim_offset_b
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_size_in = dim_size_in_c
        dim_offset = dim_offset_c
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_size_in = dim_size_in_d
        dim_offset = dim_offset_d
        total_elements = total_elements_d

    block_start = pid_x * BLOCK_X
    offsets = tl.arange(0, BLOCK_X)
    mask = block_start + offsets < total_elements

    idx = block_start + offsets

    # Decompose the flat input index into (pre, concat-dim, post) coordinates.
    pre_idx = idx // (dim_size_in * dim_prod_post)
    dim_idx = (idx // dim_prod_post) % dim_size_in
    post_idx = idx % dim_prod_post

    # Recombine against the output layout, shifting by this input's offset
    # along the concatenation dimension.
    out_idx = (
        pre_idx * dim_size_out * dim_prod_post
        + (dim_idx + dim_offset) * dim_prod_post
        + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)
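

# Python wrapper: validates shapes, promotes dtypes, allocates the output, and
# launches the copy kernel over the inputs in batches of up to four.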
def hstack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]
) -> torch.Tensor:
    logger.debug("GEMS HSTACK")

    if len(tensors) == 0:
        raise RuntimeError("hstack expected a non-empty TensorList")

    # Work on a mutable copy so 0-d inputs can be replaced with 1-d views
    # (the argument may be a tuple).
    tensors = list(tensors)
    if tensors[0].ndim == 0:
        tensors[0] = tensors[0].view(1)
    inp0_shape = tensors[0].shape
    out_shape = list(inp0_shape)
    inp_shapes = [inp0_shape]

    # hstack concatenates along dim 0 for 1-D inputs and along dim 1 otherwise.
    if len(inp0_shape) == 1:
        dim = 0
    else:
        dim = 1

    for tensor_num, tensor in enumerate(tensors[1:]):
        if tensor.ndim == 0:
            # Persist the 1-d view so the shapes recomputed below also see it.
            tensor = tensor.view(1)
            tensors[tensor_num + 1] = tensor
        if tensor.ndim != tensors[0].ndim:
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {tensors[0].ndim} and {tensor.ndim}"
            )

        inp_shape = tensor.shape
        inp_shapes.append(inp_shape)

        for i in range(len(inp_shape)):
            if i != dim and inp_shape[i] != inp0_shape[i]:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {inp0_shape[i]} but got size {inp_shape[i]} "
                    f"for tensor number {tensor_num + 1} in the list."
                )

    inp_shapes = [list(_.shape) for _ in tensors]
    inp0_shape = inp_shapes[0]

    # Type promotion: find the common dtype for all tensors
    dtypes = [t.dtype for t in tensors]
    dtype = dtypes[0]
    for dt in dtypes[1:]:
        dtype = torch.promote_types(dtype, dt)
    # Convert all tensors to the common dtype if needed
    tensors = [t.to(dtype) if t.dtype != dtype else t for t in tensors]
    device = tensors[0].device
    out_shape[dim] = sum(s[dim] for s in inp_shapes)
    out = torch.empty(out_shape, dtype=dtype, device=device)

    BLOCK = 1024
    dim_offset = 0
    i = 0
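    # The kernel copies at most four inputs per launch, so walk the tensor list in
    # batches of four, carrying the running offset along the concat dim across launches.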
    while i < len(tensors):
        tensors_in_batch = tensors[i : i + 4]
        num_tensors_in_batch = len(tensors_in_batch)

        # Per input slot: (tensor, dim_size_in, dim_offset, total_elements).
        args = []
        total_elements_list = []
        current_dim_offset = dim_offset

        for j in range(4):
            if j < num_tensors_in_batch:
                tensor = tensors_in_batch[j].contiguous()
                shape = tensor.shape
                total_elements = tensor.numel()
                dim_size_in = shape[dim]

                args.extend([tensor, dim_size_in, current_dim_offset, total_elements])
                total_elements_list.append(total_elements)
                current_dim_offset += dim_size_in
            else:
                # Add placeholders for unused tensor slots
                args.extend([tensors_in_batch[0], 0, 0, 0])
                total_elements_list.append(0)

        dim_size_out = out_shape[dim]
        # Number of elements after the concat dim, i.e. the contiguous stride of that dim.
        dim_prod_post = 1
        for d in range(dim + 1, tensors[0].ndim):
            dim_prod_post *= tensors[0].shape[d]

        # Grid axis 0 tiles the largest input in the batch; axis 1 selects the input slot.
        grid_y = num_tensors_in_batch
        max_elements_in_batch = max(total_elements_list) if total_elements_list else 0
        grid = (triton.cdiv(max_elements_in_batch, BLOCK), grid_y)

        (
            tensor_a,
            dim_size_in_a,
            dim_offset_a,
            total_elements_a,
            tensor_b,
            dim_size_in_b,
            dim_offset_b,
            total_elements_b,
            tensor_c,
            dim_size_in_c,
            dim_offset_c,
            total_elements_c,
            tensor_d,
            dim_size_in_d,
            dim_offset_d,
            total_elements_d,
        ) = args

        hstack_copy_func_kernel_4[grid](
            out,
            tensor_a,
            tensor_b,
            tensor_c,
            tensor_d,
            dim_size_in_a,
            dim_size_in_b,
            dim_size_in_c,
            dim_size_in_d,
            dim_size_out,
            dim_prod_post,
            dim_offset_a,
            dim_offset_b,
            dim_offset_c,
            dim_offset_d,
            total_elements_a,
            total_elements_b,
            total_elements_c,
            total_elements_d,
            BLOCK_X=BLOCK,
        )

        dim_offset = current_dim_offset
        i += num_tensors_in_batch

    return out