Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmm.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import broadcastable_to, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Child logger under the package-wide "flag_gems" logger; lstrip(".") drops any
# leading dots from __name__ so getChild() receives a clean relative name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

# Default kernel decorator: autotune over (M, N, K).
# NOTE(review): `generate_configs` is not a stock triton.autotune() parameter
# and `configs` is empty -- presumably the Kunlunxin Triton fork generates the
# "addmm" config space on its own; confirm against the vendor toolchain docs.
autotune_decorator = triton.autotune(
    configs=[],
    generate_configs="addmm",
    key=["M", "N", "K"],
)

21 

22 

# Opt-out switch: export KLX_USE_AUTOTUNE=0 to swap the autotuner for the
# fixed block-size heuristics below (autotuning is the default).
KLX_USE_AUTOTUNE = os.environ.get("KLX_USE_AUTOTUNE", "1") == "1"

if not KLX_USE_AUTOTUNE:

    def heur_block_m(args):
        """Derive BLOCK_SIZE_M from M: 2 for a single row, M itself up to 32, else 128."""
        m = args["M"]
        if m == 1:
            return 2
        return m if m <= 32 else 128

    def heur_block_n(args):
        """Derive BLOCK_SIZE_N from N: 2 for a single column, N itself up to 32, else 128."""
        n = args["N"]
        if n == 1:
            return 2
        return n if n <= 32 else 128

    def heur_block_k(args):
        """Derive BLOCK_SIZE_K: cover all of K in one step, capped at 128."""
        return min(args["K"], 128)

    # Replace the autotuner with the static heuristics above.
    autotune_decorator = triton.heuristics(
        {
            "BLOCK_SIZE_M": heur_block_m,
            "BLOCK_SIZE_N": heur_block_n,
            "BLOCK_SIZE_K": heur_block_k,
        }
    )

55 

56 

@libentry()
@autotune_decorator
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = alpha * A @ B + beta * I.

    A is (M, K), B is (K, N) and I (the bias) has been broadcast to (M, N) by
    the caller.  The 2D launch grid tiles M along axis 0 and N along axis 1;
    the K dimension is walked in BLOCK_SIZE_K steps.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate in fp32 regardless of the input dtype for accuracy.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask out-of-range rows/columns and the K tail: the previous unmasked
        # loads read out of bounds whenever M, N or K was not an exact
        # multiple of its block size.  other=0.0 keeps the dot() sum intact.
        k_mask = offs_k < K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & k_mask[None, :], other=0.0)
        b = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_bn[None, :] < N), other=0.0)
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    # Scale, add the bias, cast back to the bias dtype and store the tile.
    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)

109 

110 

def addmm(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    """Return ``alpha * (mat1 @ mat2) + beta * bias`` as a freshly allocated tensor.

    ``bias`` may be any shape broadcastable to (M, N).  ``mat1`` is made
    contiguous; ``mat2`` is consumed through its native strides.
    """
    logger.debug("GEMS ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    N = mat2.shape[1]

    mat1 = mat1.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out

150 

151 

def addmm_out(bias, mat1, mat2, *, beta=1.0, alpha=1.0, out=None):
    """Compute ``alpha * (mat1 @ mat2) + beta * bias``, writing into ``out``.

    When ``out`` is None a new (M, N) tensor is allocated; otherwise its shape
    must already be (M, N) and it is written through its own strides.
    """
    logger.debug("GEMS ADDMM OUT")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    N = mat2.shape[1]
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"

    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out