Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmm.py: 0%
85 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import broadcastable_to, libentry
11from flag_gems.utils import triton_lang_extension as tle
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Default tuning strategy: let Triton autotune over backend-generated
# "addmm" configs keyed on the problem shape.
autotune_decorator = triton.autotune(
    configs=[],
    generate_configs="addmm",
    key=["M", "N", "K"],
)

# Set KLX_USE_AUTOTUNE=0 to skip autotuning and use fixed heuristics instead.
KLX_USE_AUTOTUNE = os.environ.get("KLX_USE_AUTOTUNE", "1") == "1"

if not KLX_USE_AUTOTUNE:

    def _heur_block(extent):
        """Pick a tile size for a dimension of length ``extent``.

        A length-1 dimension is bumped to 2 (a degenerate tile of 1 is
        avoided), small dimensions (<= 32) use their exact size, and
        anything larger falls back to a fixed 128-wide tile.
        """
        if extent == 1:
            return 2
        if extent <= 32:
            return extent
        return 128

    def heur_block_m(args):
        # Tile size along M (rows of mat1 / output).
        return _heur_block(args["M"])

    def heur_block_n(args):
        # Tile size along N (cols of mat2 / output).
        return _heur_block(args["N"])

    def heur_block_k(args):
        # Tile size along K (reduction dimension), capped at 128.
        return min(args["K"], 128)

    autotune_decorator = triton.heuristics(
        {
            "BLOCK_SIZE_M": heur_block_m,
            "BLOCK_SIZE_N": heur_block_n,
            "BLOCK_SIZE_K": heur_block_k,
        }
    )
@libentry()
@autotune_decorator
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile of
    ``c = alpha * (a @ b) + beta * i``.

    ``a`` is (M, K), ``b`` is (K, N), ``i`` (the bias) and ``c`` are
    (M, N); all operands are addressed through explicit strides, so
    non-contiguous inputs are supported. The 2D launch grid maps
    program_id(0) -> M tiles and program_id(1) -> N tiles.
    ``alpha``/``beta`` are marked do_not_specialize so changing them
    does not trigger recompilation.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Row/col offsets of this tile, and the K offsets for the first slab.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate the matmul in fp32 regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # NOTE(review): these loads are unmasked, so tiles that overhang
        # M, N, or a K not divisible by BLOCK_SIZE_K read out of bounds.
        # Presumably the kunlunxin backend guarantees in-bounds tiling
        # (e.g. via the block-size heuristics) -- confirm before reuse.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both operand pointers one K-slab forward.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: add the scaled bias and store; here the boundary IS masked.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    accumulator = accumulator * alpha + bias * beta
    # Cast back to the bias dtype before the masked store.
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    ``mat1`` is (M, K), ``mat2`` is (K, N); ``bias`` must broadcast to
    (M, N). ``mat1`` is made contiguous before launch; ``mat2`` keeps
    its original strides, which the kernel consumes directly.
    """
    logger.debug("GEMS ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"

    M, K = mat1.shape
    N = mat2.shape[1]

    mat1 = mat1.contiguous()
    # mat2 is deliberately NOT copied: its strides are forwarded below.
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out
def addmm_out(bias, mat1, mat2, *, beta=1.0, alpha=1.0, out=None):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` into ``out``.

    When ``out`` is None a fresh (M, N) tensor is allocated with
    ``mat1``'s dtype/device; otherwise ``out`` must already have shape
    (M, N). The (possibly allocated) ``out`` is returned.
    """
    logger.debug("GEMS ADDMM OUT")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"

    M, K = mat1.shape
    N = mat2.shape[1]
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"

    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out