Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mv.py: 0%
62 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13from .mm import mm
15logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_block_n(args):
    """Heuristic for BLOCK_N: split N across ~12 programs, rounded up to a power of two."""
    rows_per_program = triton.cdiv(args["N"], 12)
    return triton.next_power_of_2(rows_per_program)
def heur_block_m(args):
    """Heuristic for BLOCK_M: next power of two of M, capped at 4096.

    The original imported ``builtins`` locally to call ``builtins.min``;
    at module scope nothing shadows ``min``, so the plain builtin is
    identical and the redundant import is dropped.
    """
    return min(triton.next_power_of_2(args["M"]), 4096)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("mv"),
#     key=["M", "N"],
# )
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "BLOCK_M": heur_block_m,
    }
)
@triton.jit
def mv_kernel(
    A,  # pointer to the (N, M) matrix
    B,  # pointer to the length-M vector
    C,  # pointer to the length-N output vector
    N: tl.constexpr,  # rows of A == length of C
    M: tl.constexpr,  # columns of A == length of B
    stride_an: tl.constexpr,  # stride of A along the N (row) axis
    stride_am: tl.constexpr,  # stride of A along the M (column) axis
    stride_bm: tl.constexpr,  # stride of B
    stride_cn: tl.constexpr,  # stride of C
    BLOCK_N: tl.constexpr,  # rows handled per program (heur_block_n)
    BLOCK_M: tl.constexpr,  # columns reduced per loop step (heur_block_m)
    buffer_size_limit: tl.constexpr,  # NOTE: `constexpr` so it can be used as a shape value.
    # NOTE(review): buffer_size_limit is never read in this body — presumably
    # kept for launch-signature compatibility; confirm before removing.
):
    # Matrix-vector product C = A @ B. Each program owns BLOCK_N rows and
    # sweeps the M (reduction) axis in chunks of BLOCK_M.
    pid = tle.program_id(0)
    # offset_n: (BLOCK_N, 1) row indices for this program;
    # offset_m: (1, BLOCK_M) column offsets within the current chunk.
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    # Accumulate in fp32 regardless of input dtype; out-of-range lanes load 0.0
    # so they contribute nothing to the sum.
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

    # Collapse the per-column partial products of each row to one scalar,
    # then store the BLOCK_N results (masked to valid rows).
    acc = tl.sum(acc, axis=1)
    C_ptrs = C + offset_n * stride_cn
    tl.store(C_ptrs, acc[:, None], mask=n_mask)
def mv(inp, vec):
    """Matrix-vector product ``inp @ vec`` (GEMS replacement for torch.mv).

    Args:
        inp: 2-D tensor of shape (N, M).
        vec: 1-D tensor of shape (M,).

    Returns:
        1-D tensor of shape (N,) with the same dtype/device as ``inp``.

    For M == 1 a dedicated Triton kernel is launched; otherwise the product
    is delegated to ``mm`` with ``vec`` promoted to an (M, 1) column.
    """
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    # TODO: fix autotune config has no item
    # Hard-coded shape that the default heuristics mishandle; route it to the
    # cluster variant until autotuning is fixed.
    if M == 5333 and N == 497:
        return mv_cluster(inp, vec)

    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(inp.device):
        if M == 1:
            out = torch.empty((N,), device=inp.device, dtype=inp.dtype)
            mv_kernel[grid](
                inp,
                vec,
                out,
                N,
                M,
                inp.stride(0),
                inp.stride(1),
                vec.stride(0),
                out.stride(0),
                buffer_size_limit=256,
            )
        else:
            # Save/restore the env var (and clean up even if mm raises) instead
            # of unconditionally setting and deleting it.
            prev = os.environ.get("XMLIR_MATMUL_FAST_MODE")
            os.environ["XMLIR_MATMUL_FAST_MODE"] = "1"
            try:
                # (N, M) @ (M, 1) -> (N, 1); squeeze ONLY the trailing dim.
                # A bare .squeeze() would also drop N when N == 1, yielding a
                # 0-d tensor instead of the (1,) result torch.mv returns.
                out = mm(inp, vec[:, None]).squeeze(-1)
            finally:
                if prev is None:
                    del os.environ["XMLIR_MATMUL_FAST_MODE"]
                else:
                    os.environ["XMLIR_MATMUL_FAST_MODE"] = prev
    return out
def mv_cluster(inp, vec):
    """Launch the mv Triton kernel directly for ``inp @ vec`` (shape (N,))."""
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    n_rows, n_cols = inp.shape
    result = torch.empty((n_rows,), device=inp.device, dtype=inp.dtype)

    def grid(meta):
        # One program per BLOCK_N rows of the output.
        return (triton.cdiv(n_rows, meta["BLOCK_N"]),)

    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp,
            vec,
            result,
            n_rows,
            n_cols,
            inp.stride(0),
            inp.stride(1),
            vec.stride(0),
            result.stride(0),
            buffer_size_limit=256,
        )
    return result