Coverage for src/flag_gems/runtime/backend/_ascend/ops/vstack.py: 0%
80 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-level logger named after this module's basename
# (resolves to "flag_gems.runtime._ascend.ops.vstack").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("vstack"),
    key=[
        "max_tile_elems",
    ],
)
@triton.jit
def vstack_kernel(
    itensor_ptr0,
    itensor_ptr1,
    itensor_ptr2,
    itensor_ptr3,
    output_ptr,
    local_row0,
    local_row1,
    local_row2,
    local_row3,
    exc_row_offset0,
    exc_row_offset1,
    exc_row_offset2,
    exc_row_offset3,
    total_row_offset,
    row_stride,
    max_tile_elems,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy up to four contiguous 2D inputs into their row-slices of the output.

    Grid layout:
      axis 0 tiles the flattened elements of one input (BLOCK_SIZE per program);
      axis 1 (0..3) selects which of the four input tensors this program copies.

    Per-tensor parameters (suffixes 0..3):
      itensor_ptrN     -- base pointer of input N (assumed contiguous; the host
                          side passes ``.contiguous()`` tensors)
      local_rowN       -- number of rows in input N
      exc_row_offsetN  -- exclusive prefix sum of rows within this 4-tensor group

    total_row_offset is the row offset of this group within the whole output,
    row_stride is elements per row, and max_tile_elems (rows_max * row_stride)
    is only an autotune key sizing the largest tile in the group.
    """
    pid_x = tle.program_id(axis=0)
    tensor_idx = tle.program_id(axis=1)
    col_idx = tl.arange(0, BLOCK_SIZE)
    # create a mask to select a corresponding tensor
    mask0 = tensor_idx == 0
    mask1 = tensor_idx == 1
    mask2 = tensor_idx == 2
    mask3 = tensor_idx == 3
    # using mask and mathematical operations to select parameters:
    # exactly one maskN is 1, so the sums below pick that tensor's values
    # without data-dependent branching.
    base_exc_row_idx = (
        mask0 * exc_row_offset0
        + mask1 * exc_row_offset1
        + mask2 * exc_row_offset2
        + mask3 * exc_row_offset3
    )
    local_row = (
        mask0 * local_row0
        + mask1 * local_row1
        + mask2 * local_row2
        + mask3 * local_row3
    )
    # Total element count of the selected tensor; int64 to avoid overflow in
    # offset arithmetic on large tensors.
    end_idx = local_row * row_stride.to(tl.int64)
    idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64)
    offset_mask = idx < end_idx
    # calculate input offset for each tensor separately
    in_offset0 = itensor_ptr0 + idx
    in_offset1 = itensor_ptr1 + idx
    in_offset2 = itensor_ptr2 + idx
    in_offset3 = itensor_ptr3 + idx
    # load data from the corresponding tensor; the maskN factor in each load's
    # mask keeps the three non-selected loads inactive (they yield 0.0)
    out0 = tl.load(in_offset0, mask=offset_mask & mask0, other=0.0)
    out1 = tl.load(in_offset1, mask=offset_mask & mask1, other=0.0)
    out2 = tl.load(in_offset2, mask=offset_mask & mask2, other=0.0)
    out3 = tl.load(in_offset3, mask=offset_mask & mask3, other=0.0)
    # consolidation result: three of the four terms are zero by construction
    out = out0 + out1 + out2 + out3
    # Destination: skip rows of earlier groups (total_row_offset) plus earlier
    # tensors in this group (base_exc_row_idx), then add the flat element index.
    row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64)
    out_offset = output_ptr + row_stride_offset + idx
    tl.store(out_offset, out, mask=offset_mask)
def vstack(tensors: list):
    """Stack a sequence of tensors vertically (row-wise), like ``torch.vstack``.

    All inputs must share device, dtype and trailing dimensions. Inputs are
    made contiguous, then copied into the output by launching
    ``vstack_kernel`` once per group of up to four tensors.
    """
    logger.debug("GEMS_ASCEND VSTACK")

    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    assert num_tensors > 0

    # Validate that every tensor matches the first one.
    ref_device = tensors[0].device
    ref_dtype = tensors[0].dtype
    ref_trailing = tensors[0].shape[1:]
    for t in tensors:
        assert (
            t.device == ref_device
            and t.dtype == ref_dtype
            and t.shape[1:] == ref_trailing
        )

    contig = [t.contiguous() for t in tensors]

    # Output has the summed row count and the shared trailing shape.
    out_shape = list(contig[0].shape)
    out_shape[0] = sum(t.shape[0] for t in contig)
    output = torch.empty(out_shape, device=ref_device, dtype=ref_dtype)
    row_stride = contig[0].stride(0)

    total_row_offset = 0
    # Process tensors in groups of four (the kernel's fixed fan-in).
    for base in range(0, num_tensors, 4):
        group = list(contig[base : base + 4])
        scheduled = len(group)

        # Row counts and their exclusive prefix sum within this group.
        rows = [t.shape[0] for t in group]
        prefix = [0]
        for r in rows[:-1]:
            prefix.append(prefix[-1] + r)
        group_rows = sum(rows)
        max_rows = max([1] + rows)

        # Pad unused slots: empty placeholder tensors, and repeat the last
        # row/prefix values (those slots are never launched on grid axis 1).
        while len(group) < 4:
            group.append(torch.empty(0, dtype=ref_dtype, device=ref_device))
            rows.append(rows[-1])
            prefix.append(prefix[-1])

        # Largest per-tensor tile in this group (elements); also autotune key.
        max_tile_elems = max_rows * row_stride
        grid = lambda META: (
            triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]),
            scheduled,
        )
        # Launch the kernel for this group.
        with torch_device_fn.device(ref_device):
            vstack_kernel[grid](
                group[0],
                group[1],
                group[2],
                group[3],
                output,
                rows[0],
                rows[1],
                rows[2],
                rows[3],
                prefix[0],
                prefix[1],
                prefix[2],
                prefix[3],
                total_row_offset,
                row_stride,
                max_tile_elems,
            )
        total_row_offset += group_rows
    return output