Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/vstack.py: 0%

77 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-level logger under the shared "flag_gems" root; lstrip(".") removes
# any leading dots so a relative __name__ still yields a valid child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

def heur_block_size(args):
    """Heuristic for the kernel's BLOCK_SIZE.

    Splits the largest tile in the group across 12 clusters (the backend's
    cluster count) and rounds the per-cluster share up to a power of two,
    as Triton requires for tl.arange extents.
    """
    per_cluster = triton.cdiv(args["max_tile_elems"], 12)  # 12 == cluster_num
    return triton.next_power_of_2(per_cluster)

19 

20 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("vstack"),
#     key=[
#         "max_tile_elems",
#     ],
# )
@triton.heuristics(
    values={
        "BLOCK_SIZE": heur_block_size,
    },
)
@triton.jit
def vstack_kernel(
    itensor_ptr0,
    itensor_ptr1,
    itensor_ptr2,
    itensor_ptr3,
    output_ptr,
    local_row0,
    local_row1,
    local_row2,
    local_row3,
    exc_row_offset0,
    exc_row_offset1,
    exc_row_offset2,
    exc_row_offset3,
    total_row_offset,
    row_stride,
    max_tile_elems,
    BLOCK_SIZE: tl.constexpr,
):
    # Copies up to four contiguous input tensors into their row ranges of the
    # output in one launch.
    #   grid axis 0: BLOCK_SIZE-element tiles over the largest input in the group
    #   grid axis 1: which of the four inputs this program instance copies
    # Per-input scalars (K = 0..3):
    #   local_rowK       - number of rows of input K
    #   exc_row_offsetK  - exclusive prefix-sum of rows within this 4-tensor group
    #   total_row_offset - rows already written to output by earlier groups
    #   row_stride       - elements per row; assumed shared by all inputs and
    #                      the output (inputs made contiguous by the caller)
    pid_x = tle.program_id(axis=0)
    tensor_idx = tle.program_id(axis=1)
    col_idx = tl.arange(0, BLOCK_SIZE)

    # Select this program's input pointer with chained where()s keyed on the
    # second grid axis — device code cannot index a list of pointers.
    intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1)
    intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr)
    intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr)
    # Matching per-input row offset and row count, selected the same way.
    base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1)
    base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx)
    base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx)
    local_row = tl.where(tensor_idx == 0, local_row0, local_row1)
    local_row = tl.where(tensor_idx == 2, local_row2, local_row)
    local_row = tl.where(tensor_idx == 3, local_row3, local_row)

    # Element count of this input; widen to int64 so large tensors don't
    # overflow 32-bit index arithmetic.
    end_idx = local_row * row_stride.to(tl.int64)
    idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64)
    # The grid is sized for the group's largest input, so smaller inputs mask
    # off their out-of-range tiles here.
    offset_mask = idx < end_idx
    in_offset = intensor_ptr + idx
    # Destination base: rows written by earlier groups plus this input's
    # exclusive offset inside the current group, in elements.
    row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64)
    out_offset = output_ptr + row_stride_offset + idx
    out = tl.load(in_offset, mask=offset_mask)
    tl.store(out_offset, out, mask=offset_mask)

75 

76 

def vstack(tensors: list):
    """Stack tensors vertically (row-wise, along dim 0).

    All inputs must share device, dtype, and trailing shape. Inputs are made
    contiguous and copied into a freshly allocated output by launching the
    Triton kernel over groups of (at most) four tensors at a time.
    """
    logger.debug("GEMS VSTACK")

    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    assert num_tensors > 0

    # All tensors must agree on device, dtype, and every dim except dim 0.
    ref = tensors[0]
    for t in tensors:
        assert (
            t.device == ref.device
            and t.dtype == ref.dtype
            and ref.shape[1:] == t.shape[1:]
        )

    c_tensors = [t.contiguous() for t in tensors]
    # Output has the summed row count and the shared trailing shape.
    total_rows = sum(t.shape[0] for t in c_tensors)
    out_shape = list(c_tensors[0].shape)
    out_shape[0] = total_rows
    output = torch.empty(out_shape, device=ref.device, dtype=ref.dtype)
    row_stride = c_tensors[0].stride(0)

    total_row_offset = 0  # rows already copied by earlier groups
    for group_start in range(0, num_tensors, 4):
        group = c_tensors[group_start : group_start + 4]
        scheduled = len(group)  # real (non-padding) tensors this launch

        # Row counts and their exclusive prefix-sum within the group.
        local_row = [t.shape[0] for t in group]
        exclusive_row = [0]
        for rows in local_row[:-1]:
            exclusive_row.append(exclusive_row[-1] + rows)
        rows_in_group = exclusive_row[-1] + local_row[-1]
        # Floor of 1 keeps max_tile_elems >= row_stride even for 0-row inputs.
        max_rows = max(1, *local_row)

        # Pad the argument slots up to four; padding programs are never
        # launched (grid axis 1 stops at `scheduled`).
        while len(group) < 4:
            group.append(torch.empty(0, dtype=ref.dtype, device=ref.device))
            local_row.append(local_row[-1])
            exclusive_row.append(exclusive_row[-1])

        max_tile_elems = max_rows * row_stride

        def grid(meta):
            return (
                triton.cdiv(max_tile_elems, meta["BLOCK_SIZE"]),
                scheduled,
            )

        # Launch on the inputs' device.
        with torch_device_fn.device(ref.device):
            vstack_kernel[grid](
                group[0],
                group[1],
                group[2],
                group[3],
                output,
                local_row[0],
                local_row[1],
                local_row[2],
                local_row[3],
                exclusive_row[0],
                exclusive_row[1],
                exclusive_row[2],
                exclusive_row[3],
                total_row_offset,
                row_stride,
                max_tile_elems,
            )
        total_row_offset += rows_in_group
    return output