Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/vstack.py: 0%
77 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_block_size(args):
    """Heuristic for BLOCK_SIZE: spread max_tile_elems across 12 clusters,
    rounded up to the next power of two."""
    # 12 is the cluster count on this backend (the original "cluster_num" note).
    per_cluster = triton.cdiv(args["max_tile_elems"], 12)
    return triton.next_power_of_2(per_cluster)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("vstack"),
#     key=[
#         "max_tile_elems",
#     ],
# )
@triton.heuristics(
    values={
        "BLOCK_SIZE": heur_block_size,
    },
)
@triton.jit
def vstack_kernel(
    # Up to four contiguous input tensors handled per launch; unused slots
    # receive padding arguments from the host side.
    itensor_ptr0,
    itensor_ptr1,
    itensor_ptr2,
    itensor_ptr3,
    output_ptr,
    # Row count of each of the four inputs.
    local_row0,
    local_row1,
    local_row2,
    local_row3,
    # Exclusive row offset of each input within this group of four.
    exc_row_offset0,
    exc_row_offset1,
    exc_row_offset2,
    exc_row_offset3,
    # Row offset of this whole group within the final output.
    total_row_offset,
    # Elements per row (stride of dim 0 of a contiguous input).
    row_stride,
    max_tile_elems,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid: axis 0 tiles the flattened elements of the largest input in the
    # group; axis 1 selects which of the (up to) four inputs to copy.
    pid_x = tle.program_id(axis=0)
    tensor_idx = tle.program_id(axis=1)
    col_idx = tl.arange(0, BLOCK_SIZE)
    # Select the per-tensor arguments for this program via scalar where-chains
    # keyed on tensor_idx (0..3).
    intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1)
    intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr)
    intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr)
    base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1)
    base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx)
    base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx)
    local_row = tl.where(tensor_idx == 0, local_row0, local_row1)
    local_row = tl.where(tensor_idx == 2, local_row2, local_row)
    local_row = tl.where(tensor_idx == 3, local_row3, local_row)
    # Total element count of this input; mask out lanes past its end
    # (the grid is sized for the largest input in the group).
    end_idx = local_row * row_stride.to(tl.int64)
    idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64)
    offset_mask = idx < end_idx
    in_offset = intensor_ptr + idx
    # Destination row = rows emitted by previous groups + rows of the
    # preceding inputs in this group; converted to a flat element offset.
    row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64)
    out_offset = output_ptr + row_stride_offset + idx
    # Flat masked copy: input elements land contiguously at the row offset.
    out = tl.load(in_offset, mask=offset_mask)
    tl.store(out_offset, out, mask=offset_mask)
def vstack(tensors: list):
    """Stack tensors vertically (row-wise) into one contiguous tensor.

    Equivalent to ``torch.vstack``: 1-D inputs are first promoted to shape
    ``(1, n)``, then all inputs are concatenated along dim 0. The copy is
    performed by ``vstack_kernel``, which handles up to four inputs per
    launch, so inputs are processed in groups of four.

    Args:
        tensors: non-empty sequence of tensors sharing device, dtype and
            trailing shape (``shape[1:]``).

    Returns:
        A new tensor on the inputs' device/dtype whose dim 0 is the sum of
        the inputs' dim-0 sizes.

    Raises:
        ValueError: if ``tensors`` is empty, or the inputs disagree on
            device, dtype or trailing shape.
    """
    logger.debug("GEMS VSTACK")

    # Promote 1-D inputs to (1, n) so everything stacks along dim 0.
    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    # Raise instead of assert: asserts are stripped under `python -O`.
    if num_tensors == 0:
        raise ValueError("vstack expects a non-empty sequence of tensors")

    # All inputs must agree on device, dtype and trailing shape.
    device = tensors[0].device
    dtype = tensors[0].dtype
    trailing_shape = tensors[0].shape[1:]  # hoisted out of the loop
    for tensor in tensors:
        if (
            tensor.device != device
            or tensor.dtype != dtype
            or tensor.shape[1:] != trailing_shape
        ):
            raise ValueError(
                "vstack requires all tensors to share device, dtype and "
                "trailing shape"
            )

    # The kernel indexes inputs as flat contiguous buffers.
    c_tensors = [t.contiguous() for t in tensors]
    # Output keeps the trailing shape; dim 0 is the total row count.
    total_rows = sum(tensor.shape[0] for tensor in c_tensors)
    output_shape = list(c_tensors[0].shape)
    output_shape[0] = total_rows
    output = torch.empty(output_shape, device=device, dtype=dtype)
    row_stride = c_tensors[0].stride(0)

    # The kernel accepts at most 4 inputs per launch; iterate in groups of 4.
    outer_iters = triton.cdiv(num_tensors, 4)
    total_row_offset = 0
    for i in range(outer_iters):
        max_rows = 1
        itensors = []
        exclusive_row = []  # exclusive row offset of each input in this group
        local_row = []  # row count of each input in this group
        array_row_offset = 0
        scheduled_num_tensors = 0
        for j in range(4):
            tensor_idx = i * 4 + j
            if tensor_idx < num_tensors:
                scheduled_num_tensors += 1
                itensors.append(c_tensors[tensor_idx])
                local_row.append(c_tensors[tensor_idx].shape[0])
                exclusive_row.append(array_row_offset)
                array_row_offset += c_tensors[tensor_idx].shape[0]
                max_rows = max(max_rows, c_tensors[tensor_idx].shape[0])
            else:
                # Pad the fixed 4-slot argument list. These slots are never
                # executed: grid axis 1 is limited to scheduled_num_tensors.
                empty_tensor = torch.empty(
                    0, dtype=c_tensors[0].dtype, device=c_tensors[0].device
                )
                itensors.append(empty_tensor)
                local_row.append(local_row[-1])
                exclusive_row.append(exclusive_row[-1])
        # Grid axis 0 tiles the flattened elements of the largest input.
        max_tile_elems = max_rows * row_stride
        grid = lambda META: (
            triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]),
            scheduled_num_tensors,
        )
        # Launch the kernel for this group of (up to) four tensors.
        with torch_device_fn.device(c_tensors[0].device):
            vstack_kernel[grid](
                itensors[0],
                itensors[1],
                itensors[2],
                itensors[3],
                output,
                local_row[0],
                local_row[1],
                local_row[2],
                local_row[3],
                exclusive_row[0],
                exclusive_row[1],
                exclusive_row[2],
                exclusive_row[3],
                total_row_offset,
                row_stride,
                max_tile_elems,
            )
        total_row_offset += array_row_offset
    return output