Coverage for src/flag_gems/ops/addmm.py: 62%
64 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable_to, libentry, libtuner
10from flag_gems.utils import triton_lang_extension as tle
# Module-level logger; configuration is inherited from the embedding application.
logger = logging.getLogger(__name__)
# Tiled GEMM-with-bias kernel: C = alpha * (A @ B) + beta * I, one
# (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile per program instance.
# Block sizes are chosen by the libtuner autotuner from the "addmm" config set,
# keyed on (M, N, K) rounded via the "align32" strategy.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,  # pointer to A, logical shape (M, K)
    b_ptr,  # pointer to B, logical shape (K, N)
    i_ptr,  # pointer to the bias input I, logical shape (M, N)
    c_ptr,  # pointer to the output C, logical shape (M, N)
    alpha,  # scale applied to A @ B (not specialized: runtime scalar)
    beta,  # scale applied to the bias (not specialized: runtime scalar)
    M,
    N,
    K,
    stride_am,  # element strides of A along (M, K)
    stride_ak,
    stride_bk,  # element strides of B along (K, N)
    stride_bn,
    stride_im,  # element strides of I along (M, N); 0 for broadcast dims
    stride_in,
    stride_cm,  # element strides of C along (M, N)
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # 2-D launch grid: axis 0 tiles rows of C, axis 1 tiles columns.
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Row/column indices of this program's output tile.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate the matmul in fp32 regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask out-of-range rows/cols and the K-tail of the last iteration;
        # out-of-bounds elements load as 0.0 and do not affect the dot product.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        # allow_tf32=False keeps full fp32 precision in the dot product.
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance along K for the next tile.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    # Bias is addressed with its own strides so broadcast (stride-0) dims work.
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    # addmm epilogue: alpha * (A @ B) + beta * I, then cast back to the
    # bias dtype before the store.
    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with a Triton kernel.

    Args:
        bias: tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: 2-D tensor of shape ``(M, K)``.
        mat2: 2-D tensor of shape ``(K, N)``.
        beta: scale applied to ``bias``.
        alpha: scale applied to ``mat1 @ mat2``.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype and device.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    # Broadcast the bias to the full output shape up front. The expanded view
    # has stride 0 along broadcast dimensions, which the kernel handles, and
    # it guarantees bias.stride(0)/stride(1) exist below. (Previously
    # bias.stride(0) was queried before broadcasting, which raised IndexError
    # for a 0-dim bias even though the shape check accepts one.)
    bias = bias.broadcast_to((M, N))

    logger.debug(
        "GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # NOTE: mat2 is intentionally left with its original layout; the kernel
    # addresses B through explicit strides, so column-major inputs avoid a copy.
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)

    # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out
def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` into ``out``.

    Args:
        bias: tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: 2-D tensor of shape ``(M, K)``.
        mat2: 2-D tensor of shape ``(K, N)``.
        beta: scale applied to ``bias``.
        alpha: scale applied to ``mat1 @ mat2``.
        out: optional ``(M, N)`` destination tensor; allocated when ``None``.

    Returns:
        ``out`` (the provided or freshly allocated destination).
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"

    # Broadcast the bias to the full output shape up front. The expanded view
    # has stride 0 along broadcast dimensions, which the kernel handles, and
    # it guarantees bias.stride(0)/stride(1) exist below. (Previously
    # bias.stride(0) was queried before broadcasting, which raised IndexError
    # for a 0-dim bias even though the shape check accepts one.)
    bias = bias.broadcast_to(out.shape)

    logger.debug(
        "GEMS ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # NOTE: mat2 keeps its original layout; the kernel addresses B through
    # explicit strides, so column-major inputs avoid a copy.

    # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out