Coverage for src/flag_gems/runtime/backend/_cambricon/ops/addmm.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable_to, libentry, libtuner 

10 

# Module-level logger nested under the package-wide "flag_gems" logger.
# NOTE(review): lstrip(".") only strips *leading* dots from __name__, which is
# a no-op for absolute module names — presumably kept for parity with the rest
# of the package; confirm against sibling modules.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(
    {
        # EVEN_* flags let the kernel skip boundary masks when a dimension is
        # an exact multiple of its block size.
        "EVEN_M": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0,
        "EVEN_N": lambda args: args["N"] % args["BLOCK_SIZE_N"] == 0,
        "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
        # A zero stride means the bias was broadcast along that dimension
        # (see `bias.broadcast_to(...)` in the Python wrappers).
        "BIAS_BROADCAST_M": lambda args: args["stride_im"] == 0,
        "BIAS_BROADCAST_N": lambda args: args["stride_in"] == 0,
    }
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,  # bias ("input" in torch.addmm terms), broadcastable to (M, N)
    c_ptr,  # output, shape (M, N)
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BIAS_BROADCAST_M: tl.constexpr,
    BIAS_BROADCAST_N: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Tiled addmm: C = alpha * (A @ B) + beta * bias.

    One program instance computes a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C,
    accumulating the A @ B partial products over K in fp32 before applying
    the alpha/beta epilogue and casting back to the bias dtype.
    """
    # 2-D launch grid: axis 0 tiles M, axis 1 tiles N.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate in fp32 regardless of input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of valid K elements left in this iteration (for the ragged
        # last tile when K is not a multiple of BLOCK_SIZE_K).
        k_remaining = K - k * BLOCK_SIZE_K
        if EVEN_M and EVEN_K:
            # Fast path: whole tile is in-bounds, no mask needed.
            a = tl.load(a_ptrs)
        else:
            a = tl.load(
                a_ptrs,
                mask=(offs_am[:, None] < M) & (offs_k[None, :] < k_remaining),
                other=0.0,  # zero-padding keeps the dot product correct
            )
        if EVEN_N and EVEN_K:
            b = tl.load(b_ptrs)
        else:
            b = tl.load(
                b_ptrs,
                mask=(offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N),
                other=0.0,
            )
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both operand pointers one K-block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

    # Force broadcast strides to the compile-time constant 0 so the bias
    # address computation is specialized for broadcast dimensions.
    if BIAS_BROADCAST_M:
        stride_im = 0

    if BIAS_BROADCAST_N:
        stride_in = 0

    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]

    if EVEN_M and EVEN_N:
        bias = tl.load(i_ptrs)
    else:
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    # Epilogue: scale, add bias, and cast back to the bias dtype for storage.
    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)

    if EVEN_M and EVEN_N:
        tl.store(c_ptrs, c)
    else:
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

117 

118 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    Args:
        bias: tensor broadcastable to ``(M, N)``.
        mat1: ``(M, K)`` matrix.
        mat2: ``(K, N)`` matrix.
        beta: scale applied to ``bias``.
        alpha: scale applied to the matmul result.

    Returns:
        A freshly allocated ``(M, N)`` tensor on ``mat1``'s device with
        ``mat1``'s dtype.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    # Log layout info before mat1 is made contiguous so the caller's
    # original strides are what gets reported.
    logger.debug(
        "GEMS_CAMBRICON ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # mat2 = mat2.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # Broadcast dims get stride 0, which the kernel's BIAS_BROADCAST_*
    # heuristics detect and specialize on.
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out

167 

168 

def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` into ``out``.

    Args:
        bias: tensor broadcastable to ``(M, N)``.
        mat1: ``(M, K)`` matrix.
        mat2: ``(K, N)`` matrix.
        beta: scale applied to ``bias``.
        alpha: scale applied to the matmul result.
        out: optional destination tensor of shape ``(M, N)``; allocated on
            ``mat1``'s device with ``mat1``'s dtype when omitted.

    Returns:
        The ``out`` tensor.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"
    # Log layout info before mat1 is made contiguous so the caller's
    # original strides are what gets reported.
    logger.debug(
        "GEMS_CAMBRICON ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # Broadcast dims get stride 0, which the kernel's BIAS_BROADCAST_*
    # heuristics detect and specialize on.
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out