Coverage for src/flag_gems/experimental_ops/mv.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def mv_kernel(
    A_ptr,  # *Pointer* to matrix A [M, N]
    x_ptr,  # *Pointer* to vector x [N]
    y_ptr,  # *Pointer* to output vector y [M]
    M,  # rows of A
    N,  # cols of A (and size of x)
    stride_am,  # stride for A along M (row stride)
    stride_an,  # stride for A along N (col stride)
    stride_xn,  # stride for x along N
    stride_ym,  # stride for y along M
    BLOCK_N: tl.constexpr,  # tile size along N
):
    """One-row-per-program matrix-vector product: y[pid_m] = dot(A[pid_m, :], x).

    Each program instance (program_id on axis 0) owns a single row of A.
    It walks that row in BLOCK_N-wide tiles, accumulating the dot product
    with x into a scalar fp32 accumulator, then stores the result to y.
    """
    pid_m = tl.program_id(axis=0)  # row of A handled by this program
    offs_n = tl.arange(0, BLOCK_N)  # in-tile column offsets [0, BLOCK_N)
    acc = tl.zeros((), dtype=tl.float32)  # scalar fp32 accumulator

    row_ptr = A_ptr + pid_m * stride_am  # base pointer of row pid_m

    for n0 in range(0, N, BLOCK_N):
        idx_n = n0 + offs_n
        # Guard the ragged final tile when N is not a multiple of BLOCK_N;
        # masked-out lanes load 0.0 and so contribute nothing to the sum.
        mask = idx_n < N
        a = tl.load(row_ptr + idx_n * stride_an, mask=mask, other=0.0)
        x = tl.load(x_ptr + idx_n * stride_xn, mask=mask, other=0.0)
        # accumulate in fp32 for better precision
        acc += tl.sum(a.to(tl.float32) * x.to(tl.float32), axis=0)

    # acc is fp32; presumably cast to y's element type on store — standard
    # Triton store behavior, confirm if y is fp64 precision matters.
    tl.store(y_ptr + pid_m * stride_ym, acc)
def _launch_mv_kernel(A: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
    """Launch ``mv_kernel`` over ``A`` with one program per row.

    Args:
        A: 2D tensor of shape [M, N].
        x: 1D tensor with N elements.
        y: 1D output tensor of shape [M], written in place by the kernel.
    """
    num_rows, num_cols = A.shape
    assert x.numel() == num_cols
    # One program per output element (row of A).
    mv_kernel[(num_rows,)](
        A,
        x,
        y,
        num_rows,
        num_cols,
        A.stride(0),
        A.stride(1),
        x.stride(0),
        y.stride(0),
        BLOCK_N=256,
        num_warps=4,
        num_stages=2,
    )
def mv(A: torch.Tensor, x: torch.Tensor):
    """Matrix-vector product ``y = A @ x`` via the Triton ``mv_kernel``.

    Args:
        A: 2D CUDA tensor of shape [M, N].
        x: 1D CUDA tensor of shape [N], on the same device as ``A``.

    Returns:
        1D tensor of shape [M] whose dtype follows PyTorch type promotion
        of ``A`` and ``x``.

    Raises:
        AssertionError: on invalid shapes, non-tensor inputs, or device
            mismatch. NOTE(review): asserts are stripped under ``python -O``;
            consider raising TypeError/ValueError if that matters here.
    """
    # Validate inputs
    assert isinstance(A, torch.Tensor) and isinstance(
        x, torch.Tensor
    ), "Inputs must be tensors"
    assert A.ndim == 2 and x.ndim == 1, "mv expects A: 2D tensor and x: 1D tensor"
    assert A.shape[1] == x.shape[0], "Incompatible dimensions for mv"
    assert (
        A.is_cuda and x.is_cuda and A.device == x.device
    ), "All tensors must be on the same CUDA device"

    # Determine output dtype following PyTorch's type promotion
    out_dtype = torch.result_type(A, x)
    M = A.shape[0]
    if M == 0:
        # Degenerate case: no rows, so skip the kernel launch entirely.
        return torch.empty((0,), device=A.device, dtype=out_dtype)

    # Prepare tensors (dtype + contiguous). `.to()`/`.contiguous()` are
    # no-ops when the tensor already matches, so nothing is copied needlessly.
    A_ = A.to(out_dtype).contiguous()
    x_ = x.to(out_dtype).contiguous()
    # A freshly allocated tensor is already contiguous, so the kernel can
    # write straight into `y`. (The previous round-trip through
    # `y.contiguous()` plus a data_ptr comparison and copy was dead code.)
    y = torch.empty((M,), device=A.device, dtype=out_dtype)

    _launch_mv_kernel(A_, x_, y)
    return y
def mv_out(A: torch.Tensor, x: torch.Tensor, out: torch.Tensor):
    """Matrix-vector product written into a caller-provided tensor.

    Computes ``out = A @ x`` in ``out``'s dtype and returns ``out``.

    Args:
        A: 2D CUDA tensor of shape [M, N].
        x: 1D CUDA tensor of shape [N].
        out: 1D CUDA tensor of shape [M]; receives the result.
    """
    # Validate inputs
    operands = (A, x, out)
    assert all(isinstance(t, torch.Tensor) for t in operands), "Inputs must be tensors"
    assert (
        A.ndim == 2 and x.ndim == 1 and out.ndim == 1
    ), "Shapes must be A: [M, N], x: [N], out: [M]"
    assert A.shape[1] == x.shape[0], "Incompatible dimensions for mv.out"
    assert out.shape[0] == A.shape[0], "Output shape must match rows of A"
    assert all(t.is_cuda for t in operands), "All tensors must be CUDA tensors"
    assert A.device == x.device == out.device, "All tensors must be on the same device"

    # Execute in the dtype of out (PyTorch .out usually determines dtype by out)
    compute_dtype = out.dtype
    num_rows = A.shape[0]
    if num_rows == 0:
        return out

    A_c = A.to(compute_dtype).contiguous()
    x_c = x.to(compute_dtype).contiguous()

    if not out.is_contiguous():
        # The kernel writes with a single row stride, so stage through a
        # contiguous scratch buffer and copy back into the strided output.
        scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_mv_kernel(A_c, x_c, scratch)
        out.copy_(scratch)
        return out
    _launch_mv_kernel(A_c, x_c, out)
    return out