# src/flag_gems/runtime/backend/_cambricon/ops/addmm.py
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import broadcastable_to, libentry, libtuner

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0,
        "EVEN_N": lambda args: args["N"] % args["BLOCK_SIZE_N"] == 0,
        "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
        "BIAS_BROADCAST_M": lambda args: args["stride_im"] == 0,
        "BIAS_BROADCAST_N": lambda args: args["stride_in"] == 0,
    }
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BIAS_BROADCAST_M: tl.constexpr,
    BIAS_BROADCAST_N: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of the output.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate in fp32 regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_remaining = K - k * BLOCK_SIZE_K
        # Full tiles take the unmasked fast path; boundary tiles use masks.
        if EVEN_M and EVEN_K:
            a = tl.load(a_ptrs)
        else:
            a = tl.load(
                a_ptrs,
                mask=(offs_am[:, None] < M) & (offs_k[None, :] < k_remaining),
                other=0.0,
            )
        if EVEN_N and EVEN_K:
            b = tl.load(b_ptrs)
        else:
            b = tl.load(
                b_ptrs,
                mask=(offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N),
                other=0.0,
            )
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

    # Broadcast dims of the bias carry stride 0, so a single element serves
    # a whole row (or column) of the tile.
    if BIAS_BROADCAST_M:
        stride_im = 0

    if BIAS_BROADCAST_N:
        stride_in = 0

    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]

    if EVEN_M and EVEN_N:
        bias = tl.load(i_ptrs)
    else:
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    # Epilogue: out = alpha * (a @ b) + beta * bias, cast to the bias dtype.
    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)

    if EVEN_M and EVEN_N:
        tl.store(c_ptrs, c)
    else:
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)
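

# A minimal eager-mode reference of what addmm_kernel computes, kept here as a
# reading aid only (an illustrative sketch, not part of the published API and
# not invoked by the wrappers below): fp32 accumulation of mat1 @ mat2, the
# alpha/beta epilogue, then a cast back to the bias dtype.
def _addmm_reference(bias, mat1, mat2, *, beta=1, alpha=1):
    acc = torch.matmul(mat1.float(), mat2.float())  # mirrors the fp32 accumulator
    out = alpha * acc + beta * bias.broadcast_to(acc.shape).float()
    return out.to(bias.dtype)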


def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    logger.debug(
        "GEMS_CAMBRICON ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # mat2 is read through its strides in the kernel, so no copy is needed:
    # mat2 = mat2.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # broadcast_to gives expanded dims stride 0, which the BIAS_BROADCAST_*
    # heuristics pick up in the kernel.
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out
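

# A minimal usage sketch for addmm (hedged: assumes a Cambricon device is
# visible to torch, conventionally "mlu"; shapes, dtype, and tolerances are
# illustrative only):
#
#   a = torch.randn(64, 32, device="mlu", dtype=torch.float16)
#   b = torch.randn(32, 128, device="mlu", dtype=torch.float16)
#   bias = torch.randn(128, device="mlu", dtype=torch.float16)
#   res = addmm(bias, a, b, alpha=2.0, beta=0.5)
#   ref = torch.addmm(bias, a, b, alpha=2.0, beta=0.5)
#   torch.testing.assert_close(res, ref, rtol=1e-3, atol=1e-3)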


def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"
    logger.debug(
        "GEMS_CAMBRICON ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out
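

if __name__ == "__main__":
    # Hedged smoke test, a sketch only: assumes a Cambricon device is
    # registered with torch under the conventional "mlu" device string.
    # Odd sizes exercise the masked (non-EVEN) paths in the kernel.
    dev = "mlu"
    a = torch.randn(65, 33, device=dev)
    b = torch.randn(33, 129, device=dev)
    bias = torch.randn(129, device=dev)
    out = torch.empty(65, 129, device=dev)
    addmm_out(bias, a, b, beta=0.5, alpha=2.0, out=out)
    ref = torch.addmm(bias, a, b, beta=0.5, alpha=2.0)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
    print("addmm_out matches torch.addmm")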