Coverage for src/flag_gems/ops/stack.py: 60%
86 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
2from typing import List, Tuple, Union
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
@triton.jit
def stack_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    """Copy up to four contiguous input tensors into their slots of a stacked output.

    Launched with grid (cdiv(numel, BLOCK_X), num_tensors_in_batch); the second
    grid axis selects which of the four inputs this program copies. Unused input
    slots are never scheduled because the caller sets grid_y to the number of
    real tensors in the batch.

    A contiguous input factors as (pre, post) with post == dim_prod_post; the
    stacked output inserts a dimension of size dim_size_out between them, and
    dim_offset_* is this input's index along that new dimension.

    NOTE: dim_size_in_a..d are kept only for signature compatibility with the
    caller (which always passes 1 for live slots); they no longer affect the
    index computation. The original code computed
    `pre_idx = idx // (dim_size_in * dim_prod_post)` and then immediately
    overwrote it with `idx // dim_prod_post` — that dead store is removed here.
    """
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    # Select this program's input pointer and per-input metadata.
    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_offset = dim_offset_a
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_offset = dim_offset_b
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_offset = dim_offset_c
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_offset = dim_offset_d
        total_elements = total_elements_d

    block_start = pid_x * BLOCK_X
    idx = block_start + tl.arange(0, BLOCK_X)
    mask = idx < total_elements

    # Decompose the flat input index into (pre, post) coordinates.
    pre_idx = idx // dim_prod_post
    post_idx = idx % dim_prod_post

    # Re-linearize into the output, which has dim_size_out slices interleaved
    # between the pre and post extents; this input occupies slice dim_offset.
    out_idx = (
        pre_idx * dim_size_out * dim_prod_post + dim_offset * dim_prod_post + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)
def stack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Stack a sequence of equally-shaped tensors along a new dimension.

    Drop-in replacement for ``torch.stack`` backed by a Triton copy kernel
    that handles up to four input tensors per launch.

    Args:
        tensors: Non-empty sequence of tensors, all with identical shape.
            Dtypes may differ; they are promoted to a common dtype.
        dim: Index of the new dimension; may be negative, valid range is
            ``[-ndim - 1, ndim]``.

    Returns:
        A new tensor of shape ``shape[:dim] + [len(tensors)] + shape[dim:]``
        with the promoted common dtype.

    Raises:
        RuntimeError: If ``tensors`` is empty or the shapes mismatch.
        IndexError: If ``dim`` is out of range for any input.
    """
    logger.debug("GEMS STACK")

    if len(tensors) == 0:
        raise RuntimeError("stack expected a non-empty TensorList")

    inp_shapes = [list(t.shape) for t in tensors]
    inp0_shape = inp_shapes[0]

    # Validate `dim` against every tensor, including the first. The original
    # loop only checked tensors[1:], so a single-tensor call with an
    # out-of-range `dim` was silently accepted and produced a wrong shape.
    for t in tensors:
        if (dim < -t.dim() - 1) or (dim > t.dim()):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -t.dim() - 1, t.dim(), dim
                )
            )
    for i, s in enumerate(inp_shapes[1:]):
        if s != inp0_shape:
            raise RuntimeError(
                f"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {s} at entry {i + 1}"
            )

    # Normalize negative dim; the +1 accounts for the newly inserted dimension.
    if dim < 0:
        dim = dim + len(inp0_shape) + 1

    # Type promotion: find the common dtype and convert inputs where needed.
    dtype = tensors[0].dtype
    for t in tensors[1:]:
        dtype = torch.promote_types(dtype, t.dtype)
    tensors = [t.to(dtype) if t.dtype != dtype else t for t in tensors]

    device = tensors[0].device
    out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
    out = torch.empty(out_shape, dtype=dtype, device=device)

    # Number of elements after the insertion point in each (contiguous) input.
    dim_prod_post = 1
    for s in inp0_shape[dim:]:
        dim_prod_post *= s

    BLOCK = 1024
    dim_size_out = len(tensors)
    # All inputs share one shape (validated above), so every tensor has the
    # same element count; one grid-x extent covers the whole batch.
    elements_per_tensor = tensors[0].numel()

    # Process the inputs in batches of up to four per kernel launch.
    for i in range(0, len(tensors), 4):
        batch = tensors[i : i + 4]
        num_in_batch = len(batch)

        # Group kernel arguments by role instead of interleaving them.
        ptrs = []
        dim_sizes_in = []
        dim_offsets = []
        totals = []
        for j in range(4):
            if j < num_in_batch:
                t = batch[j].contiguous()
                ptrs.append(t)
                dim_sizes_in.append(1)
                dim_offsets.append(i + j)  # slot index along the new dimension
                totals.append(t.numel())
            else:
                # Pad unused slots with a valid pointer and zero elements; the
                # grid's y-extent (num_in_batch) keeps them from being launched.
                ptrs.append(batch[0])
                dim_sizes_in.append(0)
                dim_offsets.append(0)
                totals.append(0)

        grid = (triton.cdiv(elements_per_tensor, BLOCK), num_in_batch)
        stack_copy_func_kernel_4[grid](
            out,
            *ptrs,
            *dim_sizes_in,
            dim_size_out,
            dim_prod_post,
            *dim_offsets,
            *totals,
            BLOCK_X=BLOCK,
        )

    return out