# Coverage header (coverage.py v7.6.9, created at 2026-03-24 15:40 +0800):
# src/flag_gems/runtime/backend/_kunlunxin/ops/addmv.py — 0% of 61 statements covered.
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable_to, libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_block_n(args):
    """Heuristic for BLOCK_N: smallest power of two covering ceil(N / 12)."""
    rows_per_program = triton.cdiv(args["N"], 12)
    return triton.next_power_of_2(rows_per_program)
def heur_block_m(args):
    """Heuristic for BLOCK_M: next power of two >= M, capped at 4096."""
    # builtins.min is used explicitly (matching the original) so the Python
    # builtin is picked up rather than any name shadowed in this namespace.
    import builtins

    candidate = triton.next_power_of_2(args["M"])
    return builtins.min(candidate, 4096)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("mv"),
#     key=["M", "N"],
# )
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "BLOCK_M": heur_block_m,
    }
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmv_kernel(
    A,  # matrix operand, indexed as (N rows, M cols) via the strides below
    B,  # vector operand of length M
    Inp,  # bias/self vector of length N
    Out,  # output vector of length N
    N,
    M,
    alpha,
    beta,
    stride_an,  # stride of A along the N (row) dimension
    stride_am,  # stride of A along the M (column) dimension
    stride_bm,  # stride of B
    stride_in,  # stride of Inp
    stride_outn,  # stride of Out
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Computes Out = alpha * (A @ B) + beta * Inp for one BLOCK_N-sized slice
    # of output rows per program instance. The M dimension is traversed in
    # BLOCK_M chunks; partial products accumulate in float32.
    pid = tle.program_id(0)
    # Row indices handled by this program — column vector, shape (BLOCK_N, 1).
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    # Column offsets within the current chunk — row vector, shape (1, BLOCK_M).
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    # Accumulate in float32 regardless of input dtype for accuracy.
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        # Mask off columns past M in the (possibly partial) last chunk.
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        # Advance both pointer grids to the next BLOCK_M chunk.
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

    # Reduce over the M axis: one dot-product value per row, kept as (BLOCK_N, 1).
    acc = tl.sum(acc, axis=1)[:, None]
    Inp_ptrs = Inp + offset_n * stride_in
    inp = tl.load(Inp_ptrs, mask=n_mask, other=0.0).to(tl.float32)
    Out_ptrs = Out + offset_n * stride_outn
    # addmv semantics: out = alpha * (A @ B) + beta * inp.
    out_block = acc * alpha + inp * beta
    tl.store(Out_ptrs, out_block, mask=n_mask)
def addmv(self, mat, vec, *, beta=1, alpha=1):
    """Return beta * self + alpha * (mat @ vec).

    mat is (N, M) and vec is (M,); self must be broadcastable to (N,).
    The result is a fresh (N,) tensor with mat's dtype and device.
    """
    logger.debug("GEMS ADDMV")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    num_rows, num_cols = mat.shape
    out = torch.empty((num_rows,), device=mat.device, dtype=mat.dtype)
    self = self.broadcast_to(out.shape)

    def grid(meta):
        # One program per BLOCK_N-sized slice of output rows.
        return (triton.cdiv(num_rows, meta["BLOCK_N"]),)

    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            num_rows,
            num_cols,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out
def addmv_out(self, mat, vec, *, beta=1, alpha=1, out=None):
    """Compute beta * self + alpha * (mat @ vec) into `out` and return it.

    mat is (N, M) and vec is (M,); self must be broadcastable to (N,).
    When `out` is None a fresh (N,) tensor is allocated with mat's dtype
    and device; otherwise `out` must already have shape (N,).
    """
    logger.debug("GEMS ADDMV OUT")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    num_rows, num_cols = mat.shape
    if out is None:
        out = torch.empty((num_rows,), device=mat.device, dtype=mat.dtype)
    else:
        assert out.shape == (num_rows,), "Incompatible output shape"

    self = self.broadcast_to(out.shape)

    def grid(meta):
        # One program per BLOCK_N-sized slice of output rows.
        return (triton.cdiv(num_rows, meta["BLOCK_N"]),)

    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            num_rows,
            num_cols,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out