Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mv.py: 0% of 67 statements
import copy
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

from ..utils import MAX_NRAM_SIZE

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


def config_prune(configs, named_args, **kwargs):
    M = named_args["M"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_warps,
            config.num_stages,
        )
        doopt = BLOCK_N * M * 4 * 3 < MAX_NRAM_SIZE
        if doopt:
            config = copy.deepcopy(config)
            BLOCK_M = config.kwargs["BLOCK_M"] = M
            num_stages = config.num_stages = 1
        elif BLOCK_M >= M:
            continue
        key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
        # Only keep one config for the same key.
        configs_map.setdefault(key, config)
    return list(configs_map.values())
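
# A rough reading of the pruning rule above, assuming MAX_NRAM_SIZE is a byte
# budget and the constant factors mean 4 bytes per float32 element and 3
# tile-sized buffers: with M = 128 and BLOCK_N = 256,
#     BLOCK_N * M * 4 * 3 = 256 * 128 * 4 * 3 = 393,216 bytes.
# If that stays below MAX_NRAM_SIZE, the config is rewritten to BLOCK_M = M and
# num_stages = 1, which in turn makes ONE_TILE_PER_CTA true in the kernel below.
# Configs whose BLOCK_M already reaches M but do not fit on-chip are dropped.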


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"] <= args["BLOCK_M"],
    },
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    pid = tl.program_id(0)
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    if ONE_TILE_PER_CTA:
        # A single tile spans all of M: load A and B once, reduce, and store.
        a = tl.load(A_ptrs, mask=n_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs).to(tl.float32)
        acc = tl.sum(a * b, axis=1)
        C_ptrs = C + offset_n * stride_cn
        tl.store(C_ptrs, acc[:, None], mask=n_mask)
    else:
        # Reduce along M in tiles of BLOCK_M, accumulating partial products.
        acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
        for m in range(0, M, BLOCK_M):
            m_mask = m + offset_m < M
            a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
            b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
            acc += a * b
            A_ptrs += BLOCK_M * stride_am
            B_ptrs += BLOCK_M * stride_bm
        acc = tl.sum(acc, axis=1)
        C_ptrs = C + offset_n * stride_cn
        tl.store(C_ptrs, acc[:, None], mask=n_mask)
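
# Each program instance owns up to BLOCK_N rows of A and reduces along M, so the
# kernel computes C[n] = sum over m of A[n, m] * B[m], matching torch.mv, with
# the accumulation carried out in float32 regardless of the input dtype.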


def mv(inp, vec):
    logger.debug("GEMS_CAMBRICON MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp,
            vec,
            out,
            N,
            M,
            inp.stride(0),
            inp.stride(1),
            vec.stride(0),
            out.stride(0),
        )
    return out
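

# Minimal usage sketch (illustrative only, not part of the module): assumes the
# Cambricon backend is installed and that "mlu" is the torch device string it
# exposes; shapes and tolerances are arbitrary.
if __name__ == "__main__":
    device = "mlu"  # assumed device string; adjust to whatever torch reports
    A = torch.randn(1024, 512, device=device, dtype=torch.float32)
    x = torch.randn(512, device=device, dtype=torch.float32)
    y = mv(A, x)
    # Sanity check against the eager reference implementation.
    print(torch.allclose(y, torch.mv(A, x), rtol=1e-3, atol=1e-3))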