Coverage for src/flag_gems/experimental_ops/mv.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def mv_kernel(
    A_ptr,  # *Pointer* to matrix A [M, N]
    x_ptr,  # *Pointer* to vector x [N]
    y_ptr,  # *Pointer* to output vector y [M]
    M,  # rows of A
    N,  # cols of A (and size of x)
    stride_am,  # stride for A along M (row stride)
    stride_an,  # stride for A along N (col stride)
    stride_xn,  # stride for x along N
    stride_ym,  # stride for y along M
    BLOCK_N: tl.constexpr,  # tile size along N
):
    """Compute one element of y = A @ x per program instance.

    Program ``pid_m`` owns row ``pid_m`` of A: it walks that row in
    BLOCK_N-wide tiles, multiplies each tile elementwise with the matching
    slice of x, and reduces the products into a scalar fp32 accumulator
    before storing the result to ``y[pid_m]``.
    """
    pid_m = tl.program_id(axis=0)  # one program per output row
    offs_n = tl.arange(0, BLOCK_N)
    acc = tl.zeros((), dtype=tl.float32)  # scalar fp32 accumulator

    row_ptr = A_ptr + pid_m * stride_am  # base address of this program's row

    for n0 in range(0, N, BLOCK_N):
        idx_n = n0 + offs_n
        mask = idx_n < N  # guards the final partial tile when N % BLOCK_N != 0
        # Masked lanes read 0.0, which is the identity for the sum below.
        a = tl.load(row_ptr + idx_n * stride_an, mask=mask, other=0.0)
        x = tl.load(x_ptr + idx_n * stride_xn, mask=mask, other=0.0)
        # accumulate in fp32 for better precision
        acc += tl.sum(a.to(tl.float32) * x.to(tl.float32), axis=0)

    # acc (fp32) is converted to y's element dtype on store.
    tl.store(y_ptr + pid_m * stride_ym, acc)

34 

35 

def _launch_mv_kernel(A: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
    """Launch ``mv_kernel`` with one program per row of ``A``.

    ``A`` is [M, N], ``x`` is [N], ``y`` is the preallocated [M] output.
    """
    num_rows, num_cols = A.shape
    assert x.numel() == num_cols
    mv_kernel[(num_rows,)](
        A, x, y,
        num_rows, num_cols,
        A.stride(0), A.stride(1),
        x.stride(0),
        y.stride(0),
        BLOCK_N=256,
        num_warps=4,
        num_stages=2,
    )

54 

55 

def mv(A: torch.Tensor, x: torch.Tensor):
    """Matrix-vector product ``y = A @ x`` using the Triton mv kernel.

    Args:
        A: 2D CUDA tensor of shape [M, N].
        x: 1D CUDA tensor of shape [N] on the same device as ``A``.

    Returns:
        1D tensor of shape [M] with dtype ``torch.result_type(A, x)``.
    """
    # Validate inputs
    assert isinstance(A, torch.Tensor) and isinstance(
        x, torch.Tensor
    ), "Inputs must be tensors"
    assert A.ndim == 2 and x.ndim == 1, "mv expects A: 2D tensor and x: 1D tensor"
    assert A.shape[1] == x.shape[0], "Incompatible dimensions for mv"
    assert (
        A.is_cuda and x.is_cuda and A.device == x.device
    ), "All tensors must be on the same CUDA device"

    # Determine output dtype following PyTorch's type promotion
    out_dtype = torch.result_type(A, x)
    M = A.shape[0]
    if M == 0:
        # Degenerate case: no rows, nothing to launch.
        return torch.empty((0,), device=A.device, dtype=out_dtype)

    # Promote dtypes and force contiguous layouts for the kernel
    # (.to() / .contiguous() are no-ops when already satisfied).
    A_ = A.to(out_dtype).contiguous()
    x_ = x.to(out_dtype).contiguous()
    # torch.empty always returns a contiguous tensor, so the kernel can
    # write into it directly; the original .contiguous()/data_ptr/copy_
    # round-trip on y was unreachable dead code and has been removed.
    y = torch.empty((M,), device=A.device, dtype=out_dtype)

    _launch_mv_kernel(A_, x_, y)
    return y

84 

85 

def mv_out(A: torch.Tensor, x: torch.Tensor, out: torch.Tensor):
    """In-place variant: compute ``A @ x`` into the preallocated ``out``.

    The computation runs in ``out.dtype`` (PyTorch ``.out`` convention);
    ``out`` is returned for chaining.
    """
    # Validate inputs
    assert (
        isinstance(A, torch.Tensor)
        and isinstance(x, torch.Tensor)
        and isinstance(out, torch.Tensor)
    ), "Inputs must be tensors"
    assert (
        A.ndim == 2 and x.ndim == 1 and out.ndim == 1
    ), "Shapes must be A: [M, N], x: [N], out: [M]"
    assert A.shape[1] == x.shape[0], "Incompatible dimensions for mv.out"
    assert out.shape[0] == A.shape[0], "Output shape must match rows of A"
    assert A.is_cuda and x.is_cuda and out.is_cuda, "All tensors must be CUDA tensors"
    assert A.device == x.device == out.device, "All tensors must be on the same device"

    # Execute in the dtype of out (PyTorch .out usually determines dtype by out)
    compute_dtype = out.dtype
    num_rows = A.shape[0]
    if num_rows == 0:
        return out

    A_ = A.to(compute_dtype).contiguous()
    x_ = x.to(compute_dtype).contiguous()

    if not out.is_contiguous():
        # Kernel needs a contiguous destination: compute into a scratch
        # buffer, then copy back into the strided out tensor.
        scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_mv_kernel(A_, x_, scratch)
        out.copy_(scratch)
        return out

    _launch_mv_kernel(A_, x_, out)
    return out