Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmm.py: 0%
63 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable_to, libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.autotune(
    configs=[],
    generate_configs="addmm",
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of
    ``C = alpha * (A @ B) + beta * I``.

    A is (M, K), B is (K, N), I (the bias) and C are (M, N); all operands
    are addressed through explicit strides, so no contiguity is assumed.
    The accumulation runs in float32 with TF32 disabled, then the result is
    cast back to the bias dtype before the store.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask partial tiles: the launch grid uses cdiv over M/N and this
        # loop uses cdiv over K, so boundary programs/iterations would
        # otherwise read out of bounds. Out-of-range lanes load 0.0, which
        # contributes nothing to the dot product.
        k_remaining = K - k * BLOCK_SIZE_K
        a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < k_remaining)
        b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` using the Triton kernel.

    ``mat1`` is (M, K), ``mat2`` is (K, N); ``bias`` must broadcast to
    (M, N). The output is a fresh tensor with ``mat1``'s device and dtype.
    """
    logger.debug("GEMS ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    m_size, k_size = mat1.shape
    n_size = mat2.shape[1]

    mat1 = mat1.contiguous()
    # mat2 is consumed through its explicit strides in the kernel, so it is
    # deliberately not forced contiguous here.
    out = torch.empty((m_size, n_size), device=mat1.device, dtype=mat1.dtype)
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program instance per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(m_size, meta["BLOCK_SIZE_M"]),
            triton.cdiv(n_size, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            m_size,
            n_size,
            k_size,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out
def addmm_out(bias, mat1, mat2, *, beta=1.0, alpha=1.0, out=None):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` into ``out``.

    If ``out`` is None a fresh (M, N) tensor is allocated on ``mat1``'s
    device with its dtype; otherwise ``out`` must already have shape (M, N).
    Returns ``out``.
    """
    logger.debug("GEMS ADDMM OUT")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    m_size, k_size = mat1.shape
    n_size = mat2.shape[1]
    if out is None:
        out = torch.empty((m_size, n_size), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (m_size, n_size), "Incompatible output shape"

    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program instance per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(m_size, meta["BLOCK_SIZE_M"]),
            triton.cdiv(n_size, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            m_size,
            n_size,
            k_size,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out