# src/flag_gems/ops/addmv.py
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import broadcastable_to, libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger(__name__)


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmv_kernel(
    A,  # (N, M) matrix
    B,  # (M,) vector
    Inp,  # (N,) tensor added after scaling by beta
    Out,  # (N,) output
    N,
    M,
    alpha,
    beta,
    stride_an,
    stride_am,
    stride_bm,
    stride_in,
    stride_outn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Each program computes BLOCK_N rows of the output, reducing over M.
    pid = tle.program_id(0)
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    # Accumulate elementwise products in float32, BLOCK_M columns at a time.
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

    # Reduce over M, then fuse the epilogue: out = alpha * (A @ B) + beta * Inp.
    acc = tl.sum(acc, axis=1)[:, None]
    Inp_ptrs = Inp + offset_n * stride_in
    inp = tl.load(Inp_ptrs, mask=n_mask, other=0.0).to(tl.float32)
    Out_ptrs = Out + offset_n * stride_outn
    out_block = acc * alpha + inp * beta
    tl.store(Out_ptrs, out_block, mask=n_mask)
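

# A minimal reference sketch of what the kernel computes; _addmv_reference is a
# hypothetical helper for documentation, not part of the flag_gems API. Like the
# kernel, it accumulates in float32 and casts back to the matrix dtype.
def _addmv_reference(inp, mat, vec, *, beta=1, alpha=1):
    # out[n] = alpha * sum_m(mat[n, m] * vec[m]) + beta * inp[n]
    mv = mat.to(torch.float32) @ vec.to(torch.float32)
    return (alpha * mv + beta * inp.to(torch.float32)).to(mat.dtype)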


def addmv(self, mat, vec, *, beta=1, alpha=1):
    """Compute beta * self + alpha * (mat @ vec) into a freshly allocated tensor."""
    logger.debug("GEMS ADDMV")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    # self may be a scalar or 1-element tensor; expand it to the output shape.
    self = self.broadcast_to(out.shape)
    # One program per BLOCK_N rows of the output.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out
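

# Usage sketch (hypothetical shapes; any device flag_gems supports works):
#
#   mat = torch.randn(4096, 512, device="cuda", dtype=torch.float16)
#   vec = torch.randn(512, device="cuda", dtype=torch.float16)
#   inp = torch.randn(4096, device="cuda", dtype=torch.float16)
#   res = addmv(inp, mat, vec, beta=0.5, alpha=2.0)  # 0.5 * inp + 2.0 * mat @ vec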


def addmv_out(self, mat, vec, *, beta=1, alpha=1, out=None):
    """Same as addmv, but write the result into a caller-provided out tensor."""
    logger.debug("GEMS ADDMV OUT")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    if out is None:
        out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    else:
        assert out.shape == (N,), "Incompatible output shape"

    self = self.broadcast_to(out.shape)
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out
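

# A minimal smoke test, assuming a CUDA device and float32 inputs; this block is
# illustrative and was not part of the original module.
if __name__ == "__main__":
    device = "cuda"
    mat = torch.randn(128, 64, device=device)
    vec = torch.randn(64, device=device)
    inp = torch.randn(128, device=device)
    expected = torch.addmv(inp, mat, vec, beta=0.5, alpha=2.0)
    got = addmv(inp, mat, vec, beta=0.5, alpha=2.0)
    torch.testing.assert_close(got, expected, rtol=1e-4, atol=1e-4)
    # Exercise the out= variant against the same reference result.
    buf = torch.empty(128, device=device)
    addmv_out(inp, mat, vec, beta=0.5, alpha=2.0, out=buf)
    torch.testing.assert_close(buf, expected, rtol=1e-4, atol=1e-4)
    print("addmv / addmv_out match torch.addmv")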