Coverage for src/flag_gems/ops/vstack.py: 74%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("vstack"),
    key=["max_tile_elems"],
)
@triton.jit
def vstack_kernel(
    itensor_ptr0,
    itensor_ptr1,
    itensor_ptr2,
    itensor_ptr3,
    output_ptr,
    local_row0,
    local_row1,
    local_row2,
    local_row3,
    exc_row_offset0,
    exc_row_offset1,
    exc_row_offset2,
    exc_row_offset3,
    total_row_offset,
    row_stride,
    max_tile_elems,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy up to four input buffers into one output buffer, one after another.

    Grid axis 0 tiles the flattened elements of an input in chunks of
    ``BLOCK_SIZE``; grid axis 1 selects which of the four input slots this
    program handles.

    For slot ``k`` the program copies ``local_row{k} * row_stride`` elements
    from ``itensor_ptr{k}`` to ``output_ptr``, starting at element offset
    ``(total_row_offset + exc_row_offset{k}) * row_stride``.

    ``max_tile_elems`` is not read in the body; it is the autotune key.
    """
    tile_id = tle.program_id(axis=0)
    slot = tle.program_id(axis=1)
    lane = tl.arange(0, BLOCK_SIZE)

    # Pick the source pointer for this slot (slot 1 is the fall-through case).
    src_ptr = tl.where(slot == 0, itensor_ptr0, itensor_ptr1)
    src_ptr = tl.where(slot == 2, itensor_ptr2, src_ptr)
    src_ptr = tl.where(slot == 3, itensor_ptr3, src_ptr)

    # Row offset of this slot relative to the first slot of the launch.
    slot_row_start = tl.where(slot == 0, exc_row_offset0, exc_row_offset1)
    slot_row_start = tl.where(slot == 2, exc_row_offset2, slot_row_start)
    slot_row_start = tl.where(slot == 3, exc_row_offset3, slot_row_start)

    # Number of rows held by this slot's input.
    slot_rows = tl.where(slot == 0, local_row0, local_row1)
    slot_rows = tl.where(slot == 2, local_row2, slot_rows)
    slot_rows = tl.where(slot == 3, local_row3, slot_rows)

    # Widen once so both products below are carried out in 64-bit.
    stride64 = row_stride.to(tl.int64)

    # Flat element indices covered by this program; mask off the tail that
    # lies beyond the end of this slot's input.
    elem_idx = (tile_id * BLOCK_SIZE + lane).to(tl.int64)
    n_elems = slot_rows * stride64
    in_bounds = elem_idx < n_elems

    # Destination begins after the rows already accounted for by
    # ``total_row_offset`` plus the rows of the preceding slots.
    dst_base = (total_row_offset + slot_row_start) * stride64
    src_addr = src_ptr + elem_idx
    dst_addr = output_ptr + dst_base + elem_idx

    values = tl.load(src_addr, mask=in_bounds)
    tl.store(dst_addr, values, mask=in_bounds)
def vstack(tensors: list):
    """Stack tensors vertically (along dimension 0) with ``vstack_kernel``.

    Every input is promoted to at least 2-D with ``torch.atleast_2d`` and
    made contiguous. The inputs are then copied into a newly allocated
    output, up to four inputs per kernel launch.

    Args:
        tensors: Non-empty sequence of tensors. After promotion to 2-D they
            must share device, dtype and all dimensions except the first.

    Returns:
        A new tensor whose first dimension is the sum of the inputs' first
        dimensions and whose remaining dimensions match the inputs.

    Raises:
        AssertionError: If ``tensors`` is empty or the inputs disagree on
            device, dtype or trailing shape.
    """
    logger.debug("GEMS VSTACK")

    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    assert num_tensors > 0, "vstack expects a non-empty sequence of tensors"

    # Ensure all tensors are on the same device and have the same dtype
    device = tensors[0].device
    dtype = tensors[0].dtype
    trailing_shape = tensors[0].shape[1:]
    for tensor in tensors:
        assert (
            tensor.device == device
            and tensor.dtype == dtype
            and trailing_shape == tensor.shape[1:]
        ), "vstack inputs must share device, dtype and trailing dimensions"

    c_tensors = [t.contiguous() for t in tensors]

    # Calculate the output shape
    total_rows = sum(t.shape[0] for t in c_tensors)
    output_shape = list(c_tensors[0].shape)
    output_shape[0] = total_rows
    output = torch.empty(output_shape, device=device, dtype=dtype)

    # Number of elements in one row. Derived from the shape instead of
    # ``c_tensors[0].stride(0)``: the stride of a size-1 dimension is not
    # taken into account for contiguity, so a single-row input that is
    # already contiguous (e.g. ``x.t()`` for ``x`` of shape ``(n, 1)``) can
    # report a ``stride(0)`` that is not the row length.
    row_stride = 1
    for dim in output_shape[1:]:
        row_stride *= dim

    slots_per_launch = 4  # the kernel takes exactly four input pointers
    total_row_offset = 0
    for start in range(0, num_tensors, slots_per_launch):
        itensors = list(c_tensors[start : start + slots_per_launch])
        scheduled_num_tensors = len(itensors)

        # Row count per scheduled input and its exclusive prefix sum, i.e.
        # the row at which each input starts within this launch.
        local_row = [t.shape[0] for t in itensors]
        exclusive_row = []
        array_row_offset = 0
        for rows in local_row:
            exclusive_row.append(array_row_offset)
            array_row_offset += rows
        max_rows = max(1, max(local_row))

        # Fill the unused slots. They are never scheduled (grid axis 1 only
        # spans the real inputs), but the kernel signature still needs four
        # pointers and four row counts/offsets.
        num_padding = slots_per_launch - scheduled_num_tensors
        if num_padding:
            placeholder = torch.empty(0, dtype=dtype, device=device)
            itensors.extend([placeholder] * num_padding)
            local_row.extend([local_row[-1]] * num_padding)
            exclusive_row.extend([exclusive_row[-1]] * num_padding)

        max_tile_elems = max_rows * row_stride
        grid = lambda META: (
            triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]),
            scheduled_num_tensors,
        )
        # Launch the kernel
        with torch_device_fn.device(device):
            vstack_kernel[grid](
                itensors[0],
                itensors[1],
                itensors[2],
                itensors[3],
                output,
                local_row[0],
                local_row[1],
                local_row[2],
                local_row[3],
                exclusive_row[0],
                exclusive_row[1],
                exclusive_row[2],
                exclusive_row[3],
                total_row_offset,
                row_stride,
                max_tile_elems,
            )
        total_row_offset += array_row_offset
    return output