Coverage for src/flag_gems/ops/vstack.py: 74%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("vstack"),
    key=[
        "max_tile_elems",
    ],
)
@triton.jit
def vstack_kernel(
    itensor_ptr0,
    itensor_ptr1,
    itensor_ptr2,
    itensor_ptr3,
    output_ptr,
    local_row0,
    local_row1,
    local_row2,
    local_row3,
    exc_row_offset0,
    exc_row_offset1,
    exc_row_offset2,
    exc_row_offset3,
    total_row_offset,
    row_stride,
    max_tile_elems,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy up to four contiguous input tensors row-wise into the output.

    Launch grid:
      * axis 0 tiles the flattened elements of the largest input in this
        batch (``cdiv(max_tile_elems, BLOCK_SIZE)`` programs);
      * axis 1 selects which of the four input slots this program copies.
        The host launches only as many axis-1 programs as there are real
        (non-padded) tensors, so padded slots are never touched.

    Args (per slot k in 0..3):
        itensor_ptrK: base pointer of the k-th contiguous input tensor.
        local_rowK: number of rows in the k-th input.
        exc_row_offsetK: exclusive prefix sum of rows of slots 0..k-1
            within this 4-tensor batch.
        total_row_offset: rows already written to ``output_ptr`` by
            previous 4-tensor batches.
        row_stride: elements per row (dim-0 stride); the host guarantees
            all inputs share trailing dimensions, so this is common to
            every input and the output.
        max_tile_elems: element count of the largest input in the batch;
            used only as the autotune key.
        BLOCK_SIZE: elements copied per axis-0 program.
    """
    pid_x = tle.program_id(axis=0)
    tensor_idx = tle.program_id(axis=1)
    col_idx = tl.arange(0, BLOCK_SIZE)

    # Scalar selection: chained tl.where picks this program's input
    # pointer, row count and exclusive row offset out of the four slots.
    intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1)
    intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr)
    intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr)
    base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1)
    base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx)
    base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx)
    local_row = tl.where(tensor_idx == 0, local_row0, local_row1)
    local_row = tl.where(tensor_idx == 2, local_row2, local_row)
    local_row = tl.where(tensor_idx == 3, local_row3, local_row)

    # Total element count of this input; 64-bit math to avoid overflow on
    # large tensors. Programs past the end are masked out, since the grid
    # is sized for the *largest* input in the batch.
    end_idx = local_row * row_stride.to(tl.int64)
    idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64)
    offset_mask = idx < end_idx
    in_offset = intensor_ptr + idx
    # Destination starts at (rows written so far + rows of earlier slots
    # in this batch) * row_stride; inputs are contiguous, so a flat copy
    # of the element range is a row-wise copy.
    row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64)
    out_offset = output_ptr + row_stride_offset + idx
    out = tl.load(in_offset, mask=offset_mask)
    tl.store(out_offset, out, mask=offset_mask)

64 

65 

def vstack(tensors: list):
    """Stack tensors vertically (row-wise, along dim 0).

    Inputs are promoted to at least 2-D, validated for matching device,
    dtype and trailing shape, then copied into a freshly allocated output
    by ``vstack_kernel`` in batches of up to four tensors per launch.

    Args:
        tensors: non-empty sequence of tensors sharing device, dtype and
            all dimensions except dim 0.

    Returns:
        A new contiguous tensor whose dim 0 is the sum of the inputs'.

    Raises:
        ValueError: if ``tensors`` is empty or the tensors disagree on
            device, dtype or trailing shape.
    """
    logger.debug("GEMS VSTACK")

    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    # Explicit raise instead of assert: asserts are stripped under -O.
    if num_tensors == 0:
        raise ValueError("vstack expects a non-empty sequence of tensors")

    device = tensors[0].device
    dtype = tensors[0].dtype
    trailing_shape = tensors[0].shape[1:]
    for tensor in tensors:
        if (
            tensor.device != device
            or tensor.dtype != dtype
            or tensor.shape[1:] != trailing_shape
        ):
            raise ValueError(
                "vstack expects all tensors to share device, dtype and "
                "trailing dimensions"
            )

    c_tensors = [t.contiguous() for t in tensors]
    # Output shape: same trailing dims, dim 0 is the total row count.
    total_rows = sum(tensor.shape[0] for tensor in c_tensors)
    output_shape = list(c_tensors[0].shape)
    output_shape[0] = total_rows
    output = torch.empty(output_shape, device=device, dtype=dtype)
    # All inputs are contiguous with identical trailing shapes, so they
    # share one row stride (elements per row).
    row_stride = c_tensors[0].stride(0)

    # One shared zero-size placeholder for unused kernel pointer slots;
    # never dereferenced because the launch grid's axis 1 only covers the
    # scheduled (real) tensors. Hoisted so it is allocated at most once.
    placeholder = torch.empty(0, dtype=dtype, device=device)

    # The kernel accepts four inputs per launch; process in batches of 4.
    outer_iters = triton.cdiv(num_tensors, 4)
    total_row_offset = 0
    for i in range(outer_iters):
        max_rows = 1
        itensors = []
        exclusive_row = []
        local_row = []
        array_row_offset = 0
        scheduled_num_tensors = 0
        for j in range(4):
            tensor_idx = i * 4 + j
            if tensor_idx < num_tensors:
                t = c_tensors[tensor_idx]
                scheduled_num_tensors += 1
                itensors.append(t)
                local_row.append(t.shape[0])
                # Exclusive prefix sum of rows within this batch.
                exclusive_row.append(array_row_offset)
                array_row_offset += t.shape[0]
                max_rows = max(max_rows, t.shape[0])
            else:
                # Pad unused slots; values are never read by the kernel.
                itensors.append(placeholder)
                local_row.append(local_row[-1])
                exclusive_row.append(exclusive_row[-1])
        max_tile_elems = max_rows * row_stride
        grid = lambda META: (
            triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]),
            scheduled_num_tensors,
        )
        # Launch the kernel on the inputs' device.
        with torch_device_fn.device(device):
            vstack_kernel[grid](
                itensors[0],
                itensors[1],
                itensors[2],
                itensors[3],
                output,
                local_row[0],
                local_row[1],
                local_row[2],
                local_row[3],
                exclusive_row[0],
                exclusive_row[1],
                exclusive_row[2],
                exclusive_row[3],
                total_row_offset,
                row_stride,
                max_tile_elems,
            )
        total_row_offset += array_row_offset
    return output