Coverage for src/flag_gems/runtime/backend/_cambricon/ops/kron.py: 0%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Module-level logger following the package-wide logging convention.
logger = logging.getLogger(__name__)

12 

13 

14def prepare_tensor_for_kron(tensor_a, tensor_b): 

15 a_shape = list(tensor_a.shape) 

16 b_shape = list(tensor_b.shape) 

17 

18 if tensor_a.numel() == 0 or tensor_b.numel() == 0: 

19 if not a_shape: 

20 a_shape = [0] 

21 if not b_shape: 

22 b_shape = [0] 

23 

24 if len(a_shape) > len(b_shape): 

25 b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape 

26 elif len(b_shape) > len(a_shape): 

27 a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape 

28 

29 out_shape = tuple(a * b for a, b in zip(a_shape, b_shape)) 

30 return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape 

31 

32 if len(a_shape) < 2: 

33 a_shape = [1] * (2 - len(a_shape)) + a_shape 

34 if len(b_shape) < 2: 

35 b_shape = [1] * (2 - len(b_shape)) + b_shape 

36 

37 if len(a_shape) > len(b_shape): 

38 b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape 

39 elif len(b_shape) > len(a_shape): 

40 a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape 

41 

42 out_shape = tuple(a * b for a, b in zip(a_shape, b_shape)) 

43 return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape 

44 

45 

def calculate_indices(batch_idx, shape_a, shape_b):
    """Map a flat output-batch index to the flat batch indices of a and b.

    The output batch dimensions are the elementwise product of the two
    operands' batch dimensions; each output coordinate splits into an
    a-coordinate (quotient) and a b-coordinate (remainder).
    """
    a_dims = shape_a[:-2] or (1,)
    b_dims = shape_b[:-2] or (1,)
    combined = tuple(x * y for x, y in zip(a_dims, b_dims))

    # Unravel batch_idx into per-dimension coordinates of the output batch.
    coords = []
    rem = batch_idx
    for size in reversed(combined):
        coords.append(rem % size)
        rem //= size
    coords.reverse()

    # Re-ravel the split coordinates into flat indices for each operand.
    a_flat = 0
    b_flat = 0
    for coord, a_dim, b_dim in zip(coords, a_dims, b_dims):
        a_flat = a_flat * a_dim + coord // b_dim
        b_flat = b_flat * b_dim + coord % b_dim

    return a_flat, b_flat

63 

64 

# Tiled Kronecker-product kernel: each program covers one (BLOCK_M, BLOCK_N)
# region of one output batch, sub-tiled into (BLOCK_TILE_M, BLOCK_TILE_N)
# pieces. C[m, n] = A[m // M2, n // N2] * B[m % M2, n % N2].
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 512, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        ),
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 2048, "BLOCK_TILE_M": 64, "BLOCK_TILE_N": 64}
        ),
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 8192, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        ),
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 32768, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        ),
    ],
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def kron_kernel(
    a_ptr,  # A flattened to (batch, M1, N1)
    b_ptr,  # B flattened to (batch, M2, N2)
    c_ptr,  # output flattened to (batch, M, N)
    map_ptr,  # int64 pairs: map[2*i], map[2*i+1] = (a batch idx, b batch idx) for output batch i
    batch_size: tl.int64,
    M: tl.int64,  # output rows, M == M1 * M2
    N: tl.int64,  # output cols, N == N1 * N2
    M1: tl.int64,
    M2: tl.int64,
    N1: tl.int64,
    N2: tl.int64,
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_TILE_M: tl.constexpr,
    BLOCK_TILE_N: tl.constexpr,
):
    # Decompose the 1-D program id into (batch, block row, block col).
    pid = tle.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    # Look up which source batches feed this output batch (interleaved pairs).
    offset = batch_id * 2
    is_valid_batch = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid_batch)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid_batch)

    num_tiles_m = tl.cdiv(BLOCK_M, BLOCK_TILE_M)
    num_tiles_n = tl.cdiv(BLOCK_N, BLOCK_TILE_N)

    for tile_m in range(num_tiles_m):
        for tile_n in range(num_tiles_n):
            tile_offset_m = tile_m * BLOCK_TILE_M
            tile_offset_n = tile_n * BLOCK_TILE_N

            # Absolute output coordinates covered by this sub-tile.
            current_offs_m = (
                block_m * BLOCK_M + tile_offset_m + tl.arange(0, BLOCK_TILE_M)
            )
            current_offs_n = (
                block_n * BLOCK_N + tile_offset_n + tl.arange(0, BLOCK_TILE_N)
            )

            # Mask out-of-range rows/cols and programs past the last batch.
            tile_mask = (
                (current_offs_m[:, None] < M)
                & (current_offs_n[None, :] < N)
                & is_valid_batch
            )

            # Kronecker index split: quotient selects the A element,
            # remainder selects the B element.
            a_row = current_offs_m[:, None] // M2
            a_col = current_offs_n[None, :] // N2
            b_row = current_offs_m[:, None] % M2
            b_col = current_offs_n[None, :] % N2

            a_idx = (
                a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
            )
            b_idx = (
                b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1
            )

            a = tl.load(a_ptr + a_idx, mask=tile_mask)
            b = tl.load(b_ptr + b_idx, mask=tile_mask)
            c = a * b

            c_idx = (
                batch_id * c_batch_stride
                + current_offs_m[:, None] * c_stride_0
                + current_offs_n[None, :] * c_stride_1
            )

            tl.store(c_ptr + c_idx, c, mask=tile_mask)

171 

172 

def kron(A, B):
    """Kronecker product of A and B via a tiled Triton kernel.

    Args:
        A: first operand.
        B: second operand.

    Returns:
        Tensor of dtype ``torch.promote_types(A.dtype, B.dtype)`` whose
        shape is the elementwise product of the rank-aligned input shapes;
        flattened to 1-D when both inputs have dim <= 1 (matching
        ``torch.kron``'s vector behavior).
    """
    logger.debug("GEMS_CAMBRICON KRON")
    if A.dim() == 0 and B.dim() == 0:
        return A * B

    if A.numel() == 0 or B.numel() == 0:
        # Only the output shape matters for empty inputs; the reshaped
        # operands are not needed.
        _, _, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # A scalar operand degenerates to an elementwise scale.
    if A.dim() == 0:
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2

    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # Collapse all batch dimensions; the kernel works on (batch, rows, cols).
    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()

    # Build the interleaved (a_idx, b_idx) batch map on the host and copy it
    # to the device in one transfer; assigning elements of a device tensor
    # one by one would issue a host->device write per batch entry.
    index_pairs = []
    for i in range(batch_size):
        a_idx, b_idx = calculate_indices(i, A_prepared.shape, B_prepared.shape)
        index_pairs.append(a_idx)
        index_pairs.append(b_idx)
    batch_indices = torch.tensor(index_pairs, device=A.device, dtype=torch.int64)

    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N

    with torch_device_fn.device(A.device):
        # One program per (batch, BLOCK_M, BLOCK_N) region of the output.
        grid = lambda meta: (
            batch_size
            * triton.cdiv(M, meta["BLOCK_M"])
            * triton.cdiv(N, meta["BLOCK_N"]),
        )

        kron_kernel[grid](
            A_view,
            B_view,
            C_reshaped,
            batch_indices,
            batch_size,
            M,
            N,
            M1,
            M2,
            N1,
            N2,
            A_view.stride(1),
            A_view.stride(2),
            B_view.stride(1),
            B_view.stride(2),
            C_reshaped.stride(1),
            C_reshaped.stride(2),
            a_batch_stride,
            b_batch_stride,
            c_batch_stride,
        )

    # torch.kron of two vectors (or scalars broadcast to vectors) is 1-D.
    if A.dim() <= 1 and B.dim() <= 1:
        return C.reshape(-1)

    return C