Coverage for src/flag_gems/runtime/backend/_ascend/ops/outer.py: 0%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5from triton import language as tl 

6 

7from flag_gems.ops.mv import mv 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

11 

12 

13# The outer kernel requires 3 parameters to determine the splitting method, 

14# but during actual tuning, you only need to determine the total size of the split blocks. 

15# Based on the second input length N and the total size of the split blocks, 

16# the 3 parameters that determine the splitting method can be calculated. 

17# Therefore, the conversion between these two is achieved through early_config_prune. 

def early_config_prune(configs, named_args, **kwargs):
    """Expand each 1-D ``tile_size`` tuning config into the kernel's real
    split parameters.

    The kernel is tuned only on the total tile size; from that and the
    second-input length ``N`` we derive ``BLOCK_N`` (capped at ``N``),
    ``BLOCK_M`` (the remaining factor), and whether an inner loop over the
    N dimension is needed at all.

    Args:
        configs: candidate ``triton.Config`` objects carrying ``tile_size``.
        named_args: positional kernel arguments by name; fallback source of ``N``.
        **kwargs: keyword kernel arguments; preferred source of ``N``.

    Returns:
        A list of ``triton.Config`` objects keyed by ``BLOCK_M`` / ``BLOCK_N``
        / ``NEED_LOOP_N``.
    """
    # N may arrive either as a keyword or as a named positional argument.
    N = kwargs["N"] if "N" in kwargs else named_args["N"]

    pruned = []
    for cfg in configs:
        total = cfg.kwargs["tile_size"]
        # Never make the N-block wider than N itself; give the rest to M.
        n_block = min(total, N)
        m_block = triton.cdiv(total, n_block)
        pruned.append(
            triton.Config(
                {
                    "BLOCK_M": m_block,
                    "BLOCK_N": n_block,
                    "NEED_LOOP_N": n_block < N,
                },
                num_stages=cfg.num_stages,
                num_warps=cfg.num_warps,
            )
        )
    return pruned

37 

38 

@libentry()
@triton.autotune(
    # Only the total tile size is enumerated here; early_config_prune converts
    # each candidate into concrete (BLOCK_M, BLOCK_N, NEED_LOOP_N) parameters
    # based on the runtime value of N.
    configs=[
        triton.Config({"tile_size": 1024}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 2048}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 4096}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 8192}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 16384}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 21760}, num_stages=3, num_warps=1),
        triton.Config({"tile_size": 32768}, num_stages=3, num_warps=1),
    ],
    key=["M", "N"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def outer_kernel(
    lhs,  # pointer to 1-D tensor of length M
    rhs,  # pointer to 1-D tensor of length N
    res,  # pointer to output, logically (M, N) row-major
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NEED_LOOP_N: tl.constexpr,  # True iff BLOCK_N < N, i.e. N must be tiled
):
    # Compute res[i, j] = lhs[i] * rhs[j], tiled into BLOCK_M x BLOCK_N blocks.
    # Tiles are distributed across programs in a strided fashion so the grid
    # size can be capped (see the launcher) while still covering all tiles.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)

    m_tasks_num = tl.cdiv(M, BLOCK_M)
    n_tasks_num = tl.cdiv(N, BLOCK_N)
    total_tasks_num = m_tasks_num * n_tasks_num

    if NEED_LOOP_N:
        # General path: both M and N are tiled; each task id maps to one
        # (start_m, start_n) tile.
        for task_id in range(pid, total_tasks_num, num_jobs):
            start_m = task_id // n_tasks_num
            start_n = task_id % n_tasks_num

            offset_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
            lhs_val = tl.load(lhs + offset_m, mask=offset_m < M)

            offset_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
            rhs_val = tl.load(rhs + offset_n, mask=offset_n < N)

            # Broadcast to form the (BLOCK_M, BLOCK_N) outer-product tile.
            res_val = lhs_val[:, None] * rhs_val[None, :]

            offset_r = offset_m[:, None] * N + offset_n[None, :]
            tl.store(
                res + offset_r,
                res_val,
                mask=(offset_m[:, None] < M) & (offset_n[None, :] < N),
            )
    else:
        # Fast path: BLOCK_N >= N, so the whole rhs vector fits in one block
        # and can be loaded once, outside the loop over M tiles.
        offset_n = tl.arange(0, BLOCK_N)
        rhs_val = tl.load(rhs + offset_n)
        for task_id in range(pid, total_tasks_num, num_jobs):
            start_m = task_id // n_tasks_num

            offset_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
            lhs_val = tl.load(lhs + offset_m, mask=offset_m < M)

            res_val = lhs_val[:, None] * rhs_val[None, :]

            offset_r = offset_m[:, None] * N + offset_n[None, :]
            tl.store(
                res + offset_r,
                res_val,
                mask=(offset_m[:, None] < M) & (offset_n[None, :] < N),
            )

107 

108 

def outer_(lhs, rhs):
    """Compute the outer product of two 1-D tensors with the Triton kernel.

    Args:
        lhs: 1-D tensor of length m.
        rhs: 1-D tensor of length n.

    Returns:
        A new (m, n) tensor with ``res[i, j] = lhs[i] * rhs[j]``, allocated
        on ``lhs``'s device with ``lhs``'s dtype.
    """
    m = lhs.shape[0]
    n = rhs.shape[0]
    # Allocate on lhs.device rather than the bare "npu" string: the bare
    # string resolves to the *current* NPU, which may differ from the device
    # the inputs actually live on (e.g. "npu:1").
    res = torch.empty([m, n], dtype=lhs.dtype, device=lhs.device)
    # One program per output tile, capped at 65535 programs; the kernel
    # strides over any remaining tiles.
    grid = lambda META: (
        min(
            triton.cdiv(m, META["BLOCK_M"]) * triton.cdiv(n, META["BLOCK_N"]),
            65535,
        ),
    )
    outer_kernel[grid](lhs, rhs, res, m, n)
    return res

122 

123 

class Outer(torch.autograd.Function):
    """Differentiable outer product of two 1-D tensors.

    Forward computes ``out[i, j] = inp[i] * weight[j]``; backward uses two
    matrix-vector products to produce the gradients w.r.t. each input.
    """

    @staticmethod
    def forward(ctx, inp, weight):
        """Return the (m, n) outer product of 1-D tensors inp and weight."""
        logger.debug("GEMS_ASCEND OUTER")
        assert inp.ndim == 1 and weight.ndim == 1, "Invalid input"
        out = outer_(inp, weight)
        # Both operands are needed to form the VJP in backward.
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        """VJP: d(inp) = out_grad @ weight, d(weight) = out_grad.T @ inp."""
        logger.debug("GEMS_ASCEND OUTER VJP")
        # Fixed typo in the assertion message ("invalide" -> "invalid").
        assert out_grad.ndim == 2, "invalid out_grad shape"

        inp, weight = ctx.saved_tensors

        inp_grad = mv(out_grad, weight)
        weight_grad = mv(out_grad.t(), inp)

        return inp_grad, weight_grad

144 

145 

def outer(inp, weight):
    """Differentiable outer product of 1-D tensors ``inp`` and ``weight``.

    Thin public entry point that dispatches to the autograd-aware
    :class:`Outer` function.
    """
    result = Outer.apply(inp, weight)
    return result