Coverage for src/flag_gems/ops/stack.py: 60%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

import logging
from typing import List, Tuple, Union

import torch
import triton
import triton.language as tl

# Module-level logger, keyed by module name per the standard logging convention.
logger = logging.getLogger(__name__)

@triton.jit
def stack_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    """Copy up to four contiguous inputs into their slices of a stacked output.

    Launch grid is 2-D: axis 0 tiles a single input's elements in chunks of
    BLOCK_X; axis 1 (0..3) selects which of the four inputs this program
    copies.  The host side launches with grid_y == number of real inputs in
    the batch, so padding slots are never executed.

    For a contiguous input element with linear index ``idx``, the output
    element index is::

        pre_idx * dim_size_out * dim_prod_post
        + dim_offset * dim_prod_post
        + post_idx

    with ``pre_idx = idx // dim_prod_post`` and
    ``post_idx = idx % dim_prod_post`` — i.e. the output has an extra axis of
    length ``dim_size_out`` at the stack dimension, and this input occupies
    slot ``dim_offset`` of it.

    The ``dim_size_in_*`` parameters are retained for signature stability but
    are not needed by the index math: each stacked input contributes a slice
    of extent 1 along the new dimension.
    """
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    # Select the per-input scalars for the tensor this program handles.
    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_offset = dim_offset_a
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_offset = dim_offset_b
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_offset = dim_offset_c
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_offset = dim_offset_d
        total_elements = total_elements_d

    block_start = pid_x * BLOCK_X
    offsets = tl.arange(0, BLOCK_X)
    mask = block_start + offsets < total_elements

    idx = block_start + offsets

    # Bug fix: the original computed
    #   pre_idx = idx // (dim_size_in * dim_prod_post)
    # and then immediately overwrote it with the line below — the first
    # assignment was dead code and has been removed.
    post_idx = idx % dim_prod_post
    pre_idx = idx // dim_prod_post

    out_idx = (
        pre_idx * dim_size_out * dim_prod_post + dim_offset * dim_prod_post + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)

74 

75 

def stack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Stack a sequence of equal-sized tensors along a new dimension.

    Mirrors ``torch.stack``: the output gains a new axis of length
    ``len(tensors)`` at position ``dim``; dtypes are promoted to a common
    type via ``torch.promote_types``.

    Args:
        tensors: non-empty sequence of tensors, all with identical shape.
        dim: position of the new axis; negative values count from the end,
            valid range is ``[-ndim - 1, ndim]``.

    Returns:
        A new tensor of shape ``shape[:dim] + [len(tensors)] + shape[dim:]``.

    Raises:
        RuntimeError: if ``tensors`` is empty or the shapes differ.
        IndexError: if ``dim`` is out of range for any input tensor.
    """
    logger.debug("GEMS STACK")

    if len(tensors) == 0:
        raise RuntimeError("stack expected a non-empty TensorList")

    inp_shapes = [list(_.shape) for _ in tensors]
    inp0_shape = inp_shapes[0]

    # Validate `dim` against every tensor.  (The original loop started at
    # entry 1, so the first tensor — and the single-tensor case — was never
    # checked, silently producing a wrong-shaped output for out-of-range dims.)
    for t in tensors:
        if (dim < -t.dim() - 1) or (dim > t.dim()):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -t.dim() - 1, t.dim(), dim
                )
            )
    for i, s in enumerate(inp_shapes[1:]):
        if s != inp0_shape:
            raise RuntimeError(
                f"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {s} at entry {i + 1}"
            )

    if dim < 0:
        dim = dim + len(inp0_shape) + 1

    # Type promotion: find the common dtype for all tensors.
    dtypes = [t.dtype for t in tensors]
    dtype = dtypes[0]
    for dt in dtypes[1:]:
        dtype = torch.promote_types(dtype, dt)
    # Convert all tensors to the result dtype if needed.
    tensors = [t.to(dtype) if t.dtype != dtype else t for t in tensors]
    device = tensors[0].device
    out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
    out = torch.empty(out_shape, dtype=dtype, device=device)

    # Product of the dimensions at and after the stack position; the kernel
    # uses it to split a linear index into (pre, post) parts.
    dim_prod_post = 1
    for s in inp0_shape[dim:]:
        dim_prod_post *= s

    BLOCK = 1024
    # Process inputs in batches of up to 4 — the kernel takes 4 input
    # pointers and dispatches on program_id(1).
    i = 0
    while i < len(tensors):
        tensors_in_batch = tensors[i : i + 4]
        num_tensors_in_batch = len(tensors_in_batch)

        args = []
        total_elements_list = []

        for j in range(4):
            if j < num_tensors_in_batch:
                tensor = tensors_in_batch[j].contiguous()
                total_elements = tensor.numel()
                # (tensor, dim_size_in=1, dim_offset along the new axis, numel)
                args.extend([tensor, 1, i + j, total_elements])
                total_elements_list.append(total_elements)
            else:
                # Padding slot: grid_y == num_tensors_in_batch, so this slot
                # is never launched; any valid pointer works as a dummy.
                args.extend([tensors_in_batch[0], 0, 0, 0])
                total_elements_list.append(0)

        dim_size_out = len(tensors)

        grid_y = num_tensors_in_batch
        # All real inputs have equal numel (shapes were verified identical);
        # take the max so grid sizing is robust regardless.
        max_elements_in_batch = max(total_elements_list)
        grid = (triton.cdiv(max_elements_in_batch, BLOCK), grid_y)

        (
            tensor_a,
            dim_size_in_a,
            dim_offset_a,
            total_elements_a,
            tensor_b,
            dim_size_in_b,
            dim_offset_b,
            total_elements_b,
            tensor_c,
            dim_size_in_c,
            dim_offset_c,
            total_elements_c,
            tensor_d,
            dim_size_in_d,
            dim_offset_d,
            total_elements_d,
        ) = args

        stack_copy_func_kernel_4[grid](
            out,
            tensor_a,
            tensor_b,
            tensor_c,
            tensor_d,
            dim_size_in_a,
            dim_size_in_b,
            dim_size_in_c,
            dim_size_in_d,
            dim_size_out,
            dim_prod_post,
            dim_offset_a,
            dim_offset_b,
            dim_offset_c,
            dim_offset_d,
            total_elements_a,
            total_elements_b,
            total_elements_c,
            total_elements_d,
            BLOCK_X=BLOCK,
        )
        i += num_tensors_in_batch

    return out