Coverage for src/flag_gems/runtime/backend/_ascend/ops/vstack.py: 0%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-scoped logger; the f-string resolves to the bare module name ("vstack")
# so records appear under "flag_gems.runtime._ascend.ops.vstack".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

13 

14 

# Copies up to four input tensors into their destination row ranges of the
# stacked output in a single launch.
#
# Grid layout: axis 0 tiles the flattened elements of one input tensor in
# BLOCK_SIZE chunks; axis 1 selects which of the four tensor "slots" this
# program handles. Per-slot parameters (row counts, exclusive row offsets)
# are selected with arithmetic masks (mask * value sums) rather than
# branches, so all programs execute the same instruction stream.
#
# Inputs are assumed contiguous; `row_stride` is the element count of one
# row, so `idx` indexes the flattened tensor directly.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("vstack"),
    key=[
        "max_tile_elems",
    ],
)
@triton.jit
def vstack_kernel(
    itensor_ptr0,  # input tensor pointers for slots 0..3
    itensor_ptr1,
    itensor_ptr2,
    itensor_ptr3,
    output_ptr,  # base pointer of the stacked output tensor
    local_row0,  # shape[0] (row count) of each slot's tensor
    local_row1,
    local_row2,
    local_row3,
    exc_row_offset0,  # exclusive prefix sum of row counts within this group
    exc_row_offset1,
    exc_row_offset2,
    exc_row_offset3,
    total_row_offset,  # rows already written by previous groups of four
    row_stride,  # elements per row (stride(0) of a contiguous input)
    max_tile_elems,  # autotune key: elements of the largest tensor in the group
    BLOCK_SIZE: tl.constexpr,
):
    pid_x = tle.program_id(axis=0)
    tensor_idx = tle.program_id(axis=1)
    col_idx = tl.arange(0, BLOCK_SIZE)

    # create a mask to select a corresponding tensor
    mask0 = tensor_idx == 0
    mask1 = tensor_idx == 1
    mask2 = tensor_idx == 2
    mask3 = tensor_idx == 3

    # Using masks and arithmetic to select this slot's parameters: exactly one
    # mask is 1, so the sum picks the matching value without branching.
    base_exc_row_idx = (
        mask0 * exc_row_offset0
        + mask1 * exc_row_offset1
        + mask2 * exc_row_offset2
        + mask3 * exc_row_offset3
    )

    local_row = (
        mask0 * local_row0
        + mask1 * local_row1
        + mask2 * local_row2
        + mask3 * local_row3
    )

    # Total elements of this slot's tensor; int64 to avoid overflow on large
    # tensors. Lanes past the end are masked off below.
    end_idx = local_row * row_stride.to(tl.int64)
    idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64)
    offset_mask = idx < end_idx

    # calculate input offset for each tensor separately
    in_offset0 = itensor_ptr0 + idx
    in_offset1 = itensor_ptr1 + idx
    in_offset2 = itensor_ptr2 + idx
    in_offset3 = itensor_ptr3 + idx

    # Load from the selected tensor only; the other three loads are fully
    # masked (other=0.0) so they contribute nothing to the sum below.
    out0 = tl.load(in_offset0, mask=offset_mask & mask0, other=0.0)
    out1 = tl.load(in_offset1, mask=offset_mask & mask1, other=0.0)
    out2 = tl.load(in_offset2, mask=offset_mask & mask2, other=0.0)
    out3 = tl.load(in_offset3, mask=offset_mask & mask3, other=0.0)

    # Consolidate: at most one of out0..out3 is non-masked, so the sum is
    # exactly the selected tensor's data.
    out = out0 + out1 + out2 + out3

    # Destination row = rows from earlier groups + rows from earlier slots in
    # this group; scale by row_stride to get the flat element offset.
    row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64)
    out_offset = output_ptr + row_stride_offset + idx
    tl.store(out_offset, out, mask=offset_mask)

89 

90 

def vstack(tensors: list):
    """Stack a sequence of tensors vertically (row-wise) into one tensor.

    Inputs are promoted to at least 2-D, validated to share device, dtype and
    trailing shape, then copied into a freshly allocated output by launching
    ``vstack_kernel`` over groups of up to four tensors at a time.
    """
    logger.debug("GEMS_ASCEND VSTACK")

    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    assert num_tensors > 0

    # All inputs must agree on device, dtype and every dimension except dim 0.
    reference = tensors[0]
    device, dtype = reference.device, reference.dtype
    for tensor in tensors:
        assert (
            tensor.device == device
            and tensor.dtype == dtype
            and reference.shape[1:] == tensor.shape[1:]
        )

    # The kernel indexes flattened storage, so force contiguity up front.
    c_tensors = [t.contiguous() for t in tensors]

    # Output shape: rows summed over all inputs, trailing dims unchanged.
    out_shape = list(c_tensors[0].shape)
    out_shape[0] = sum(t.shape[0] for t in c_tensors)
    output = torch.empty(out_shape, device=device, dtype=dtype)
    row_stride = c_tensors[0].stride(0)  # elements per row (contiguous)

    total_row_offset = 0  # rows already handled by earlier groups
    for start in range(0, num_tensors, 4):
        # The kernel consumes up to four tensors per launch.
        group = list(c_tensors[start : start + 4])
        scheduled_num_tensors = len(group)

        # Per-slot row counts and their exclusive prefix sums (each slot's
        # starting row within this group).
        local_row = [t.shape[0] for t in group]
        exclusive_row = [0]
        for rows in local_row[:-1]:
            exclusive_row.append(exclusive_row[-1] + rows)
        rows_in_group = exclusive_row[-1] + local_row[-1]

        # Pad unused slots with empty tensors; duplicated row metadata is
        # harmless since the grid never schedules those slot indices.
        while len(group) < 4:
            group.append(torch.empty(0, dtype=dtype, device=device))
            local_row.append(local_row[-1])
            exclusive_row.append(exclusive_row[-1])

        # Largest tiling size in this group (floor of 1 keeps the grid valid).
        max_tile_elems = max(1, *local_row[:scheduled_num_tensors]) * row_stride
        grid = lambda META: (
            triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]),
            scheduled_num_tensors,
        )

        # Launch the kernel on the inputs' device.
        with torch_device_fn.device(group[0].device):
            vstack_kernel[grid](
                group[0],
                group[1],
                group[2],
                group[3],
                output,
                local_row[0],  # shape[0] of slot 0
                local_row[1],  # shape[0] of slot 1
                local_row[2],  # shape[0] of slot 2
                local_row[3],  # shape[0] of slot 3
                exclusive_row[0],  # 0
                exclusive_row[1],  # rows of slot 0
                exclusive_row[2],  # rows of slots 0+1
                exclusive_row[3],  # rows of slots 0+1+2
                total_row_offset,
                row_stride,  # stride(0)
                max_tile_elems,
            )
        total_row_offset += rows_in_group
    return output