Coverage for src/flag_gems/ops/addmv.py: 63% (57 statements)

coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import broadcastable_to, libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger(__name__)


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmv_kernel(
    A,
    B,
    Inp,
    Out,
    N,
    M,
    alpha,
    beta,
    stride_an,
    stride_am,
    stride_bm,
    stride_in,
    stride_outn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Each program computes a BLOCK_N-row stripe of the matrix-vector product:
    # Out[n] = alpha * sum_m(A[n, m] * B[m]) + beta * Inp[n].
    pid = tle.program_id(0)
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    # Sweep the reduction dimension in BLOCK_M chunks, accumulating in fp32.
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

    # Collapse the per-chunk partial products, then apply the alpha/beta scaling.
    acc = tl.sum(acc, axis=1)[:, None]
    Inp_ptrs = Inp + offset_n * stride_in
    inp = tl.load(Inp_ptrs, mask=n_mask, other=0.0).to(tl.float32)
    Out_ptrs = Out + offset_n * stride_outn
    out_block = acc * alpha + inp * beta
    tl.store(Out_ptrs, out_block, mask=n_mask)


def addmv(self, mat, vec, *, beta=1, alpha=1):
    logger.debug("GEMS ADDMV")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    # self may be 0-d or single-element; broadcast it to the output shape.
    self = self.broadcast_to(out.shape)
    # One program per BLOCK_N rows of the output vector.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out
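
# A minimal usage sketch (not part of the original module), assuming a device
# supported by FlagGems; the shapes below are hypothetical. The result should
# match torch.addmv up to floating-point tolerance:
#
#     inp = torch.randn(128, device="cuda")
#     mat = torch.randn(128, 64, device="cuda")
#     vec = torch.randn(64, device="cuda")
#     res = addmv(inp, mat, vec, beta=2.0, alpha=0.5)
#     ref = torch.addmv(inp, mat, vec, beta=2.0, alpha=0.5)
#     torch.testing.assert_close(res, ref)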


def addmv_out(self, mat, vec, *, beta=1, alpha=1, out=None):
    logger.debug("GEMS ADDMV OUT")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    if out is None:
        out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    else:
        assert out.shape == (N,), "Incompatible output shape"

    # self may be 0-d or single-element; broadcast it to the output shape.
    self = self.broadcast_to(out.shape)
    # One program per BLOCK_N rows of the output vector.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out
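
# Hedged smoke test (an addition, not part of the original file): exercises
# both entry points against torch.addmv. Assumes a CUDA device; adjust the
# device string for other backends supported by FlagGems.
if __name__ == "__main__":
    device = "cuda"
    inp = torch.randn(256, device=device)
    mat = torch.randn(256, 96, device=device)
    vec = torch.randn(96, device=device)
    ref = torch.addmv(inp, mat, vec, beta=2.0, alpha=0.5)
    res = addmv(inp, mat, vec, beta=2.0, alpha=0.5)
    torch.testing.assert_close(res, ref, rtol=1e-4, atol=1e-4)
    buf = torch.empty_like(ref)
    res2 = addmv_out(inp, mat, vec, beta=2.0, alpha=0.5, out=buf)
    torch.testing.assert_close(res2, ref, rtol=1e-4, atol=1e-4)
    print("addmv and addmv_out match torch.addmv")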