Coverage for src/flag_gems/runtime/backend/_cambricon/fused/outer.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5from triton import language as tl 

6 

7from flag_gems.utils import libentry 

8 

9from ..ops import mv 

10from ..utils import TOTAL_CORE_NUM 

11 

12logger = logging.getLogger(__name__) 

13 

14 

15# The outer kernel requires 3 parameters to determine the splitting method, 

16# but during actual tuning, you only need to determine the total size of the split blocks. 

17# Based on the second input length N and the total size of the split blocks, 

18# the 3 parameters that determine the splitting method can be calculated. 

19# Therefore, the conversion between these two is achieved through early_config_prune. 

def early_config_prune(configs, named_args, **kwargs):
    """Expand each single-knob ``tile_size`` config into concrete split parameters.

    Autotuning only searches over the total tile size; the three parameters the
    kernel actually needs (BLOCK_M, BLOCK_N, NEED_LOOP_N) are derived here from
    ``tile_size`` and the second input's length ``N``:

    - BLOCK_N is capped at N (no point tiling wider than the row),
    - BLOCK_M takes whatever budget remains (tile_size / BLOCK_N, rounded up),
    - NEED_LOOP_N records whether one BLOCK_N tile covers the full row.

    Args:
        configs: candidate ``triton.Config`` objects, each carrying ``tile_size``.
        named_args: positional kernel arguments by name; ``N`` is read from here
            when it is not supplied via ``kwargs``.

    Returns:
        A new list of ``triton.Config`` objects keyed by BLOCK_M/BLOCK_N/NEED_LOOP_N.
    """
    N = kwargs["N"] if "N" in kwargs else named_args["N"]

    pruned = []
    for cfg in configs:
        budget = cfg.kwargs["tile_size"]
        n_block = min(budget, N)
        m_block = triton.cdiv(budget, n_block)
        pruned.append(
            triton.Config(
                {"BLOCK_M": m_block, "BLOCK_N": n_block, "NEED_LOOP_N": n_block < N},
                num_stages=cfg.num_stages,
                num_warps=cfg.num_warps,
            )
        )
    return pruned

39 

40 

@libentry()
@triton.autotune(
    configs=[
        # Only the total tile size is searched; early_config_prune converts it
        # into the (BLOCK_M, BLOCK_N, NEED_LOOP_N) parameters the kernel uses.
        triton.Config({"tile_size": 1024}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 2048}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 4096}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 8192}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 16384}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 21760}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 32768}, num_stages=3, num_warps=1),
    ],
    key=["M", "N"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def outer_kernel(
    lhs,  # pointer to the left 1-D input, length M
    rhs,  # pointer to the right 1-D input, length N
    res,  # pointer to the M x N output, row-major (row stride N)
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NEED_LOOP_N: tl.constexpr,  # True iff BLOCK_N < N, i.e. a row needs several tiles
):
    # Outer product res[m, n] = lhs[m] * rhs[n], tiled into BLOCK_M x BLOCK_N
    # blocks. The (m, n) tile grid is flattened into task ids and distributed
    # over the launched programs in a strided fashion (task_id advances by
    # num_jobs), so any grid size covers all tiles.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)

    m_tasks_num = tl.cdiv(M, BLOCK_M)
    n_tasks_num = tl.cdiv(N, BLOCK_N)
    total_tasks_num = m_tasks_num * n_tasks_num

    if NEED_LOOP_N:
        # General case: each task addresses one (start_m, start_n) tile and
        # must reload the rhs slice for its column range.
        for task_id in range(pid, total_tasks_num, num_jobs):
            start_m = task_id // n_tasks_num
            start_n = task_id % n_tasks_num

            offset_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
            lhs_val = tl.load(lhs + offset_m, mask=offset_m < M)

            offset_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
            rhs_val = tl.load(rhs + offset_n, mask=offset_n < N)

            # Broadcast multiply: (BLOCK_M, 1) * (1, BLOCK_N) -> (BLOCK_M, BLOCK_N).
            res_val = lhs_val[:, None] * rhs_val[None, :]

            offset_r = offset_m[:, None] * N + offset_n[None, :]
            tl.store(
                res + offset_r,
                res_val,
                mask=(offset_m[:, None] < M) & (offset_n[None, :] < N),
            )
    else:
        # Fast path: one BLOCK_N tile spans the whole row (BLOCK_N >= N), so
        # rhs is loaded once outside the loop and reused for every row tile.
        offset_n = tl.arange(0, BLOCK_N)
        rhs_val = tl.load(rhs + offset_n)
        for task_id in range(pid, total_tasks_num, num_jobs):
            # n_tasks_num == 1 here, so task_id maps directly to a row tile.
            start_m = task_id // n_tasks_num

            offset_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
            lhs_val = tl.load(lhs + offset_m, mask=offset_m < M)

            res_val = lhs_val[:, None] * rhs_val[None, :]

            offset_r = offset_m[:, None] * N + offset_n[None, :]
            tl.store(
                res + offset_r,
                res_val,
                mask=(offset_m[:, None] < M) & (offset_n[None, :] < N),
            )

109 

110 

def outer_(lhs, rhs):
    """Compute the outer product of two 1-D tensors with the Triton kernel.

    Args:
        lhs: 1-D tensor of length m (resident on an MLU device).
        rhs: 1-D tensor of length n.

    Returns:
        An (m, n) tensor with ``out[i, j] = lhs[i] * rhs[j]``, allocated on
        ``lhs``'s device with ``lhs``'s dtype.
    """
    m = lhs.shape[0]
    n = rhs.shape[0]
    # Allocate on lhs.device rather than the bare "mlu" string: with multiple
    # MLU cards, "mlu" resolves to the *current* device, which may differ from
    # the card holding the inputs and would make the kernel cross devices.
    res = torch.empty((m, n), dtype=lhs.dtype, device=lhs.device)
    # Launch at most TOTAL_CORE_NUM programs; the kernel strides over the full
    # tile grid, so a smaller grid still covers every tile.
    grid = lambda META: (
        min(
            triton.cdiv(m, META["BLOCK_M"]) * triton.cdiv(n, META["BLOCK_N"]),
            TOTAL_CORE_NUM,
        ),
    )
    outer_kernel[grid](lhs, rhs, res, m, n)
    return res

124 

125 

class Outer(torch.autograd.Function):
    """Autograd-enabled outer product of two 1-D tensors.

    Forward computes ``out = inp[:, None] * weight[None, :]`` via the Triton
    kernel; backward reduces the incoming gradient with matrix-vector products:
    ``d inp = out_grad @ weight`` and ``d weight = out_grad.T @ inp``.
    """

    @staticmethod
    def forward(ctx, inp, weight):
        logger.debug("GEMS_CAMBRICON OUTER")
        assert inp.ndim == 1 and weight.ndim == 1, "Invalid input"
        out = outer_(inp, weight)
        # Both inputs are needed to form the vjp in backward.
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS_CAMBRICON OUTER VJP")
        # Fixed typo in the assertion message ("invalide" -> "invalid").
        assert out_grad.ndim == 2, "invalid out_grad shape"

        inp, weight = ctx.saved_tensors

        # grad wrt inp:    (M, N) @ (N,) -> (M,)
        inp_grad = mv(out_grad, weight)
        # grad wrt weight: (N, M) @ (M,) -> (N,)
        weight_grad = mv(out_grad.t(), inp)

        return inp_grad, weight_grad

146 

147 

def outer(inp, weight):
    """Public entry point: outer product of two 1-D tensors with autograd support."""
    return Outer.apply(inp, weight)