Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/kron.py: 0%

117 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11 

def _pad_to_rank(shape, rank):
    """Left-pad *shape* (a list) with 1s until it has *rank* dimensions."""
    return [1] * (rank - len(shape)) + shape


def prepare_tensor_for_kron(tensor_a, tensor_b):
    """Align the shapes of two operands for a (batched) Kronecker product.

    Both shapes are padded on the left with 1s to a common rank; in the
    non-empty case they are additionally padded to at least rank 2 so the
    kernel always sees a trailing (rows, cols) pair.

    Args:
        tensor_a, tensor_b: input tensors.

    Returns:
        (a_reshaped, b_reshaped, out_shape) where out_shape is the
        element-wise product of the two aligned shapes.
    """
    a_shape = list(tensor_a.shape)
    b_shape = list(tensor_b.shape)

    if tensor_a.numel() == 0 or tensor_b.numel() == 0:
        # An empty operand: represent a 0-dim shape as [0] and skip the
        # rank-2 padding — only the (zero-sized) output shape matters.
        a_shape = a_shape or [0]
        b_shape = b_shape or [0]
    else:
        # Guarantee a trailing 2-D (rows, cols) block for the kernel.
        a_shape = _pad_to_rank(a_shape, 2)
        b_shape = _pad_to_rank(b_shape, 2)

    # Align both operands to a common rank (batch dims broadcast as 1s).
    rank = max(len(a_shape), len(b_shape))
    a_shape = _pad_to_rank(a_shape, rank)
    b_shape = _pad_to_rank(b_shape, rank)

    out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
    return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape

42 

43 

def calculate_indices(batch_idx, shape_a, shape_b):
    """Map a flat output-batch index to flat batch indices into A and B.

    The output batch dimensions are the element-wise products of the two
    operands' batch dimensions; this decomposes *batch_idx* into per-dim
    coordinates and splits each coordinate Kronecker-style (quotient goes
    to A, remainder to B).
    """
    batch_a = shape_a[:-2] or (1,)
    batch_b = shape_b[:-2] or (1,)
    batch_out = tuple(x * y for x, y in zip(batch_a, batch_b))

    # Row-major decomposition of the flat index into per-dim coordinates.
    coords = []
    rem = batch_idx
    for size in reversed(batch_out):
        coords.append(rem % size)
        rem //= size
    coords.reverse()

    a_flat = 0
    b_flat = 0
    for coord, dim_a, dim_b in zip(coords, batch_a, batch_b):
        a_flat = a_flat * dim_a + coord // dim_b
        b_flat = b_flat * dim_b + coord % dim_b

    return a_flat, b_flat

61 

62 

def heur_block_n(args):
    """Heuristic for BLOCK_N: the full column count N, capped at 8192."""
    n = args["N"]
    return n if n < 8192 else 8192

67 

68 

def heur_block_m(args):
    """Heuristic for BLOCK_M: about M/12 rows, rounded up to a power of two."""
    rows_per_block = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_block)

71 

72 

# @triton.autotune(configs=runtime.get_tuned_config("kron"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def kron_kernel(
    a_ptr,  # A operand, batched: (a_batches, M1, N1), addressed via strides
    b_ptr,  # B operand, batched: (b_batches, M2, N2), addressed via strides
    c_ptr,  # output: (batch_size, M, N)
    map_ptr,  # host-built int64 pairs per output batch: (a_batch_idx, b_batch_idx)
    batch_size: tl.int64,
    M: tl.int64,  # output rows per batch (M1 * M2)
    N: tl.int64,  # output cols per batch (N1 * N2)
    M1: tl.int64,
    M2: tl.int64,
    N1: tl.int64,  # NOTE(review): unused in the body; kept for a uniform launch signature
    N2: tl.int64,
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # One program computes a BLOCK_M x BLOCK_N tile of one output batch:
    #   C[m, n] = A[m // M2, n // N2] * B[m % M2, n % N2]
    pid = tle.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    # Decompose the linear pid into (batch, tile-row, tile-col).
    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Guard both the tile edges and any excess programs past the last batch.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) & (batch_id < batch_size)

    # Look up which A batch and B batch feed this output batch.
    offset = batch_id * 2
    is_valid = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid)

    # Kronecker index split: A advances once per M2 (resp. N2) outputs,
    # while B cycles within each A element.
    a_row = offs_m[:, None] // M2
    a_col = offs_n[None, :] // N2
    b_row = offs_m[:, None] % M2
    b_col = offs_n[None, :] % N2

    a_idx = a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
    b_idx = b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1

    a = tl.load(a_ptr + a_idx, mask=mask)
    b = tl.load(b_ptr + b_idx, mask=mask)
    c = a * b

    c_idx = (
        batch_id * c_batch_stride
        + offs_m[:, None] * c_stride_0
        + offs_n[None, :] * c_stride_1
    )
    tl.store(c_ptr + c_idx, c, mask=mask)

143 

144 

def kron(A, B):
    """Kronecker product of A and B (torch.kron semantics).

    Scalar and empty operands are resolved on the host; the general case
    launches a Triton kernel over the (batched) 2-D Kronecker product.

    Args:
        A, B: input tensors of any dimensionality; dtypes are promoted.

    Returns:
        The Kronecker product, flattened to 1-D when both inputs are <= 1-D.
    """
    # Two scalars: plain multiplication.
    if A.dim() == 0 and B.dim() == 0:
        return A * B

    # Any empty operand: only the (zero-sized) output shape is needed.
    if A.numel() == 0 or B.numel() == 0:
        _, _, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # One scalar operand: kron degenerates to elementwise scaling.
    if A.dim() == 0:
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2  # per-batch output tile size

    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # Collapse leading batch dims so the kernel sees 3-D operands.
    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    # Batch stride constants below assume densely packed batches.
    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()

    # Build the (a_idx, b_idx) batch map on the host and upload it once.
    # Assigning element-by-element into a device tensor would incur one
    # host->device transfer (with an implicit sync) per write.
    index_pairs = []
    for i in range(batch_size):
        a_idx, b_idx = calculate_indices(i, A_prepared.shape, B_prepared.shape)
        index_pairs.append(a_idx)
        index_pairs.append(b_idx)
    batch_indices = torch.tensor(index_pairs, device=A.device, dtype=torch.int64)

    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N
    with torch_device_fn.device(A.device):
        # One program per (batch, BLOCK_M tile, BLOCK_N tile).
        grid = lambda meta: (
            batch_size
            * triton.cdiv(M, meta["BLOCK_M"])
            * triton.cdiv(N, meta["BLOCK_N"]),
        )

        kron_kernel[grid](
            A_view,
            B_view,
            C_reshaped,
            batch_indices,
            batch_size,
            M,
            N,
            M1,
            M2,
            N1,
            N2,
            A_view.stride(1),
            A_view.stride(2),
            B_view.stride(1),
            B_view.stride(2),
            C_reshaped.stride(1),
            C_reshaped.stride(2),
            a_batch_stride,
            b_batch_stride,
            c_batch_stride,
        )

    # torch.kron returns a 1-D result when both inputs are <= 1-D.
    if A.dim() <= 1 and B.dim() <= 1:
        return C.reshape(-1)

    return C