Coverage for src/flag_gems/ops/mv.py: 56%
39 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
)
@triton.jit
def mv_kernel(
    A,  # pointer to the matrix operand; indexed as A[n, m] via the two strides below
    B,  # pointer to the vector operand; indexed along m via stride_bm
    C,  # pointer to the output vector; indexed along n via stride_cn
    N,  # number of rows of A == length of C
    M,  # number of columns of A == length of B
    stride_an,  # stride of A along the n (row) dimension
    stride_am,  # stride of A along the m (column) dimension
    stride_bm,  # stride of B along its only dimension
    stride_cn,  # stride of C along its only dimension
    BLOCK_N: tl.constexpr,  # rows handled per program instance (autotuned)
    BLOCK_M: tl.constexpr,  # columns processed per inner-loop step (autotuned)
):
    """Blocked matrix-vector product: C[n] = sum_m A[n, m] * B[m].

    Each program instance owns a contiguous tile of BLOCK_N rows (selected
    by program id 0) and sweeps the m dimension in BLOCK_M-sized chunks,
    accumulating partial products in float32 before a final reduction.
    """
    pid = tle.program_id(0)
    # Column vector of row indices (BLOCK_N, 1) and row vector of column
    # offsets (1, BLOCK_M); broadcasting yields a (BLOCK_N, BLOCK_M) tile.
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    # Row mask is loop-invariant; the column mask is recomputed per chunk.
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    # Accumulate in float32 regardless of input dtype for precision.
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        # Out-of-range elements load as 0.0 and thus do not affect the sum.
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        # Advance both pointer tiles one chunk along the m dimension.
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm
    # Reduce the per-column partials to one value per row.
    acc = tl.sum(acc, axis=1)
    C_ptrs = C + offset_n * stride_cn
    # acc is (BLOCK_N,); restore the trailing axis to match the masked
    # (BLOCK_N, 1) pointer/mask shape for the store.
    tl.store(C_ptrs, acc[:, None], mask=n_mask)
def mv(inp, vec):
    """Matrix-vector multiply via the Triton kernel: returns ``inp @ vec``.

    ``inp`` is a 2-D tensor of shape (N, M) and ``vec`` a 1-D tensor of
    length M; the result is a 1-D tensor of length N with ``inp``'s dtype
    on ``inp``'s device.
    """
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"

    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)

    # One program per BLOCK_N-sized slice of output rows.
    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK_N"]),)

    # Launch on the tensor's device so the kernel targets the right GPU.
    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp, vec, out,
            N, M,
            inp.stride(0), inp.stride(1),
            vec.stride(0),
            out.stride(0),
        )
    return out