Coverage for src/flag_gems/ops/addmm.py: 62%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable_to, libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    # Tiled GEMM with fused bias: C = alpha * (A @ B) + beta * I.
    # Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C,
    # iterating over K in BLOCK_SIZE_K steps. alpha/beta are runtime scalars
    # (do_not_specialize avoids recompiling per value).
    a_ptr,  # A, logical shape (M, K)
    b_ptr,  # B, logical shape (K, N)
    i_ptr,  # bias ("input" in torch.addmm terms), already broadcast to (M, N) by the caller
    c_ptr,  # output C, shape (M, N)
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # 2D launch grid: axis 0 tiles M, axis 1 tiles N.
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Row/column offsets of this program's output tile, plus the K-offsets
    # of the first K-slab. Pointers are built from explicit strides, so
    # arbitrary (e.g. transposed) layouts of A and B are supported.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate the dot products in fp32 regardless of input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask out-of-range rows/cols and the K tail; padded entries load 0.0
        # and therefore contribute nothing to the dot product.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        # allow_tf32=False keeps full fp32 precision in the MMA.
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both pointer grids to the next K-slab.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    # Fused epilogue: scale, add bias, cast back to the bias dtype, store.
    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)

81 

82 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` via the Triton kernel.

    ``mat1`` is (M, K), ``mat2`` is (K, N); ``bias`` must be broadcastable
    to (M, N). A fresh output tensor with ``mat1``'s dtype/device is returned.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    logger.debug(
        "GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )
    # A is densified; B is passed through and consumed via its strides.
    mat1 = mat1.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # Materialize the broadcast view so the kernel sees plain (M, N) strides.
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program per output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out

131 

132 

def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Out-variant of :func:`addmm`: write ``beta * bias + alpha * (mat1 @ mat2)``
    into ``out`` (allocated when ``None``) and return it.

    ``out``, when supplied, must already have shape (M, N).
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is not None:
        assert out.shape == (M, N), "Incompatible output shape"
    else:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    logger.debug(
        "GEMS ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )
    # Densify A; B and out are consumed through their strides as-is.
    mat1 = mat1.contiguous()
    # Materialize the broadcast view so the kernel sees plain (M, N) strides.
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program per output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out