Coverage for src/flag_gems/ops/hstack.py: 68%

106 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2from typing import List, Tuple, Union 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

9 

10 

@triton.jit
def hstack_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    """Copy up to four inputs into their slices of ``out_ptr``.

    Grid axis 1 selects the input slot (0 -> a, 1 -> b, 2 -> c, anything
    else -> d). Grid axis 0 walks that input in chunks of ``BLOCK_X`` flat
    elements. A flat input index is split into (pre, dim, post) coordinates
    using ``dim_size_in`` and ``dim_prod_post``, then stored at output dim
    position ``dim + dim_offset`` of an output whose concatenation-dim size
    is ``dim_size_out``.

    NOTE(review): the flat ``in_ptr + idx`` load assumes each input is
    contiguous; the host wrapper is expected to guarantee that.
    """
    # Promote to int64 before any index arithmetic. With int32 program ids
    # the flat input index and the output index wrap past 2**31 - 1.
    pid_x = tl.program_id(0).to(tl.int64)
    pid_y = tl.program_id(1)

    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_size_in = dim_size_in_a
        dim_offset = dim_offset_a
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_size_in = dim_size_in_b
        dim_offset = dim_offset_b
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_size_in = dim_size_in_c
        dim_offset = dim_offset_c
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_size_in = dim_size_in_d
        dim_offset = dim_offset_d
        total_elements = total_elements_d

    idx = pid_x * BLOCK_X + tl.arange(0, BLOCK_X)
    mask = idx < total_elements

    # Chained division instead of idx // (dim_size_in * dim_prod_post).
    # The scalar product of two int32 arguments can overflow int32, while
    # (idx // a) // b == idx // (a * b) holds for non-negative integers.
    row = idx // dim_prod_post
    pre_idx = row // dim_size_in
    dim_idx = row % dim_size_in
    post_idx = idx % dim_prod_post

    out_idx = (
        pre_idx * dim_size_out * dim_prod_post
        + (dim_idx + dim_offset) * dim_prod_post
        + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)

76 

77 

def hstack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]
) -> torch.Tensor:
    """Concatenate ``tensors`` horizontally, mirroring ``torch.hstack``.

    Every 0-d input is first viewed as a 1-element 1-D tensor. If the first
    tensor is then 1-D, inputs are concatenated along dim 0; otherwise along
    dim 1. All inputs are cast to their common promoted dtype. The copy is
    performed by ``hstack_copy_func_kernel_4``, at most four inputs per
    launch.

    Args:
        tensors: non-empty list or tuple of tensors. After 0-d promotion they
            must have equal rank and equal sizes in every dimension except
            the concatenation dimension. The caller's sequence is not
            modified.

    Returns:
        A new tensor of the promoted dtype on the first input's device.

    Raises:
        RuntimeError: if ``tensors`` is empty, the ranks differ, or sizes
            differ outside the concatenation dimension.
    """
    logger.debug("GEMS HSTACK")

    if len(tensors) == 0:
        raise RuntimeError("hstack expected a non-empty TensorList")

    # Build a local list so the caller's sequence (possibly an immutable
    # tuple) is never written to. Promote *every* 0-d tensor, not just the
    # first, so later shape[dim] lookups are valid.
    tensors = [t.view(1) if t.ndim == 0 else t for t in tensors]

    inp0_shape = tensors[0].shape
    ndim = tensors[0].ndim
    dim = 0 if ndim == 1 else 1

    for tensor_num, tensor in enumerate(tensors[1:], start=1):
        if tensor.ndim != ndim:
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {ndim} and {tensor.ndim}"
            )
        for i, (size, size0) in enumerate(zip(tensor.shape, inp0_shape)):
            if i != dim and size != size0:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {size0} but got size {size} "
                    f"for tensor number {tensor_num} in the list."
                )

    # Type promotion: find the common dtype and cast only the tensors that
    # differ from it.
    dtype = tensors[0].dtype
    for t in tensors[1:]:
        dtype = torch.promote_types(dtype, t.dtype)
    tensors = [t if t.dtype == dtype else t.to(dtype) for t in tensors]

    out_shape = list(inp0_shape)
    out_shape[dim] = sum(t.shape[dim] for t in tensors)
    out = torch.empty(out_shape, dtype=dtype, device=tensors[0].device)

    # Loop-invariant kernel parameters. dim_prod_post is the product of the
    # sizes strictly after `dim`, i.e. the contiguous stride of `dim`.
    dim_size_out = out_shape[dim]
    dim_prod_post = 1
    for s in inp0_shape[dim + 1 :]:
        dim_prod_post *= s

    BLOCK = 1024
    SLOTS = 4  # hstack_copy_func_kernel_4 takes exactly four input slots
    dim_offset = 0
    for start in range(0, len(tensors), SLOTS):
        # The kernel indexes inputs by flat offset, so they must be contiguous.
        batch = [t.contiguous() for t in tensors[start : start + SLOTS]]

        slots = []
        for t in batch:
            size_in = t.shape[dim]
            slots.append((t, size_in, dim_offset, t.numel()))
            dim_offset += size_in
        num_in_batch = len(slots)
        # Pad unused slots. grid_y == num_in_batch, so padding is never read.
        slots.extend([(batch[0], 0, 0, 0)] * (SLOTS - num_in_batch))

        max_elements = max(slot[3] for slot in slots)
        if max_elements == 0:
            continue  # nothing to copy for this batch
        grid = (triton.cdiv(max_elements, BLOCK), num_in_batch)

        (
            (tensor_a, dim_size_in_a, dim_offset_a, total_elements_a),
            (tensor_b, dim_size_in_b, dim_offset_b, total_elements_b),
            (tensor_c, dim_size_in_c, dim_offset_c, total_elements_c),
            (tensor_d, dim_size_in_d, dim_offset_d, total_elements_d),
        ) = slots

        hstack_copy_func_kernel_4[grid](
            out,
            tensor_a,
            tensor_b,
            tensor_c,
            tensor_d,
            dim_size_in_a,
            dim_size_in_b,
            dim_size_in_c,
            dim_size_in_d,
            dim_size_out,
            dim_prod_post,
            dim_offset_a,
            dim_offset_b,
            dim_offset_c,
            dim_offset_d,
            total_elements_a,
            total_elements_b,
            total_elements_c,
            total_elements_d,
            BLOCK_X=BLOCK,
        )

    return out