Coverage for src/flag_gems/runtime/backend/_sunrise/ops/addmm.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable_to, libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as ext 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one output tile of ``C = alpha * (A @ B) + beta * I``.

    Each program handles the ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile selected by
    program ids 0 (rows) and 1 (columns), walking the K dimension in steps of
    ``BLOCK_SIZE_K``.  Loads and the final store are masked against ``M``,
    ``N`` and the remaining part of ``K``, so sizes need not be multiples of
    the block sizes.

    Args:
        a_ptr, b_ptr: Pointers to the left (M x K) and right (K x N) operands.
        i_ptr: Pointer to the additive input, addressed as an M x N matrix
            through ``stride_im`` / ``stride_in``.
        c_ptr: Pointer to the M x N output.
        alpha, beta: Scalars applied to the product and to the additive
            input respectively (excluded from specialization).
        M, N, K: Problem sizes.
        stride_*: Element strides of the corresponding matrices.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Compile-time tile sizes.
        IS_FP64: Compile-time flag selecting a float64 accumulator instead
            of float32.
    """
    row_block = ext.program_id(0)
    col_block = ext.program_id(1)

    rows = row_block * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = col_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    ks = tl.arange(0, BLOCK_SIZE_K)

    # Row/column bounds do not depend on the K step, so build them once.
    row_in_bounds = rows[:, None] < M
    col_in_bounds = cols[None, :] < N

    a_tile_ptrs = a_ptr + (rows[:, None] * stride_am + ks[None, :] * stride_ak)
    b_tile_ptrs = b_ptr + (ks[:, None] * stride_bk + cols[None, :] * stride_bn)

    if IS_FP64:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of valid K elements left for this step; masks the K tail.
        k_left = K - k * BLOCK_SIZE_K
        a = tl.load(
            a_tile_ptrs,
            mask=row_in_bounds & (ks[None, :] < k_left),
            other=0.0,
        )
        b = tl.load(
            b_tile_ptrs,
            mask=(ks[:, None] < k_left) & col_in_bounds,
            other=0.0,
        )
        if IS_FP64:
            # NOTE(review): operands are narrowed to float32 before the dot
            # even though the accumulator is float64 -- presumably a backend
            # limitation of tl.dot; confirm the precision loss is acceptable.
            a = a.to(tl.float32)
            b = b.to(tl.float32)
        acc += tl.dot(a, b, allow_tf32=False)
        a_tile_ptrs += BLOCK_SIZE_K * stride_ak
        b_tile_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: the output and additive-input tiles share the row/column
    # offsets used for the operand tiles.
    tile_mask = row_in_bounds & col_in_bounds
    out_ptrs = c_ptr + stride_cm * rows[:, None] + stride_cn * cols[None, :]
    bias_ptrs = i_ptr + stride_im * rows[:, None] + stride_in * cols[None, :]
    bias = tl.load(bias_ptrs, mask=tile_mask, other=0.0)

    acc = acc * alpha + bias * beta
    # The result is cast to the additive input's element type before storing.
    tl.store(out_ptrs, acc.to(bias.dtype), mask=tile_mask)

88 

89 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` in a new tensor.

    Args:
        bias: Tensor whose shape is broadcastable to ``(M, N)``.
        mat1: 2-D tensor of shape ``(M, K)``.
        mat2: 2-D tensor of shape ``(K, N)``.
        beta: Scalar multiplier for ``bias``.
        alpha: Scalar multiplier for ``mat1 @ mat2``.

    Returns:
        A new ``(M, N)`` tensor on ``mat1``'s device with ``mat1``'s dtype.

    Raises:
        AssertionError: If the inner dimensions of ``mat1`` and ``mat2``
            differ, or ``bias`` cannot be broadcast to ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    logger.debug(
        "GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        # Logging arguments are evaluated eagerly and a 0-dim bias has no
        # dimension 0 to query, so guard the stride lookup.
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # mat2 is deliberately passed as-is: the kernel receives its strides.
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    if M == 0 or N == 0:
        # The output has no elements; skip launching on an empty grid.
        return out
    bias = bias.broadcast_to(out.shape)

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out

139 

140 

def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` into ``out``.

    Args:
        bias: Tensor whose shape is broadcastable to ``(M, N)``.
        mat1: 2-D tensor of shape ``(M, K)``.
        mat2: 2-D tensor of shape ``(K, N)``.
        beta: Scalar multiplier for ``bias``.
        alpha: Scalar multiplier for ``mat1 @ mat2``.
        out: Optional ``(M, N)`` destination tensor.  When ``None`` a new
            tensor is allocated on ``mat1``'s device with ``mat1``'s dtype.

    Returns:
        ``out`` (the supplied tensor, or the newly allocated one).

    Raises:
        AssertionError: If the inner dimensions of ``mat1`` and ``mat2``
            differ, ``bias`` cannot be broadcast to ``(M, N)``, or a
            supplied ``out`` does not have shape ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"
    logger.debug(
        "GEMS ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        # Logging arguments are evaluated eagerly and a 0-dim bias has no
        # dimension 0 to query, so guard the stride lookup.
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    if M == 0 or N == 0:
        # The output has no elements; skip launching on an empty grid.
        return out
    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out

191 

192 

def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Run addmm with an explicitly requested output dtype.

    Allocates a ``(mat1.shape[0], mat2.shape[1])`` tensor of ``out_dtype`` on
    ``mat1``'s device and hands all validation and computation over to
    ``addmm_dtype_out``.

    Args:
        bias: Additive input tensor.
        mat1: Left matrix operand.
        mat2: Right matrix operand.
        out_dtype: dtype of the returned tensor.
        beta: Scalar multiplier for ``bias``.
        alpha: Scalar multiplier for the matrix product.

    Returns:
        Whatever ``addmm_dtype_out`` returns for the freshly allocated tensor.
    """
    logger.debug("GEMS ADDMM_DTYPE")
    n_rows, n_cols = mat1.shape[0], mat2.shape[1]
    result = torch.empty((n_rows, n_cols), device=mat1.device, dtype=out_dtype)
    return addmm_dtype_out(
        bias, mat1, mat2, out_dtype, beta=beta, alpha=alpha, out=result
    )

201 

202 

def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Validate dtypes, then compute addmm into ``out`` with ``out_dtype``.

    Args:
        bias: Additive input; its dtype must equal ``out_dtype`` or
            ``mat1.dtype``.  It is converted to ``out_dtype`` before use.
        mat1: Left matrix operand.
        mat2: Right matrix operand; must share ``mat1``'s dtype.
        out_dtype: Requested result dtype.  Must equal ``mat1.dtype``, or be
            ``torch.float32`` when ``mat1`` is float16 / bfloat16.
        beta: Scalar multiplier for ``bias``.
        alpha: Scalar multiplier for the matrix product.
        out: Destination tensor (required keyword); its dtype must equal
            ``out_dtype``.

    Returns:
        The result of ``addmm_out`` called with the converted bias.

    Raises:
        RuntimeError: If any of the dtype constraints above is violated.
    """
    logger.debug("GEMS ADDMM_DTYPE_OUT")
    in_dtype = mat1.dtype

    if in_dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )

    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )

    # The only permitted widening is half precision in, fp32 out.
    widens_to_fp32 = out_dtype == torch.float32 and in_dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if out_dtype != in_dtype and not widens_to_fp32:
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )

    if bias.dtype not in (out_dtype, in_dtype):
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    converted_bias = bias.to(out_dtype)
    return addmm_out(converted_bias, mat1, mat2, beta=beta, alpha=alpha, out=out)