Coverage for src/flag_gems/runtime/backend/_metax/ops/addmm.py: 0%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import broadcastable_to, libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger("flag_gems." + __name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(
    {
        # True when the launch grid holds more than 2**32 tiles
        # (ceil(M*N / (BLOCK_M*BLOCK_N)) needs more than 32 bits).
        # Used below to pick the program-id helper; presumably the
        # tle variant is safe for >32-bit grids — TODO confirm.
        "UPGRADE": lambda args: math.ceil(
            (args["M"] * args["N"]) / (args["BLOCK_SIZE_M"] * args["BLOCK_SIZE_N"])
        ).bit_length()
        > 32,
    }
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,  # (M, K) input matrix
    b_ptr,  # (K, N) input matrix
    i_ptr,  # (M, N) bias matrix (already broadcast by the host wrapper)
    c_ptr,  # (M, N) output matrix
    alpha,  # scalar multiplier for a @ b
    beta,  # scalar multiplier for the bias
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    UPGRADE: tl.constexpr,
):
    """Tiled GEMM with fused bias: c = alpha * (a @ b) + beta * i.

    One program instance computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of c,
    iterating over K in BLOCK_SIZE_K steps and accumulating in float32.
    """
    # Select the program-id helper based on the heuristic above.
    if UPGRADE:
        pid_m = tle.program_id(0)
        pid_n = tle.program_id(1)
    else:
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

    # Row/column offsets of this tile, and pointers to the first K-slab.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate in float32 for precision regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Masks zero-fill out-of-bounds rows/cols and the ragged final K-slab.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        # allow_tf32=False keeps full fp32 precision in the dot product.
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both operands to the next K-slab.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: load the matching bias tile, fuse the scaled add, and store.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    accumulator = accumulator * alpha + bias * beta
    # Cast the fp32 accumulator back to the bias dtype before storing.
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``alpha * (mat1 @ mat2) + beta * bias`` via the MetaX Triton kernel.

    ``mat1`` is (M, K), ``mat2`` is (K, N); ``bias`` must be broadcastable to
    (M, N). The result is a new (M, N) tensor with ``mat1``'s dtype and device.
    """
    logger.debug("METAX GEMS ADDMM")

    # Validate shapes before touching any memory.
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"

    M, K = mat1.shape
    N = mat2.shape[1]

    # Log layout diagnostics from the *original* strides, before the
    # contiguous() calls below rewrite them.
    logger.debug(
        "METAX GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )

    # The kernel indexes with plain row-major strides, so densify everything.
    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    bias = bias.broadcast_to(out.shape).contiguous()

    # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out