Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmv.py: 0%

61 statements  

coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

import logging

import torch
import triton
import triton.language as tl

# from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import broadcastable_to, libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


def heur_block_n(args):
    return triton.next_power_of_2(triton.cdiv(args["N"], 12))


def heur_block_m(args):
    import builtins

    return builtins.min(triton.next_power_of_2(args["M"]), 4096)

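# addmv_kernel computes Out = alpha * (A @ B) + beta * Inp for an (N, M) matrix A,
# a length-M vector B, and a length-N vector Inp. Each program instance handles
# BLOCK_N rows of A and loops over the reduction dimension M in chunks of BLOCK_M.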

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("mv"),
#     key=["M", "N"],
# )
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "BLOCK_M": heur_block_m,
    }
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmv_kernel(
    A,
    B,
    Inp,
    Out,
    N,
    M,
    alpha,
    beta,
    stride_an,
    stride_am,
    stride_bm,
    stride_in,
    stride_outn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid = tle.program_id(0)
    # Rows of A (and entries of Out) owned by this program: shape (BLOCK_N, 1).
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    # Column offsets within the current chunk of the reduction dim: shape (1, BLOCK_M).
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    # Accumulate elementwise products A[n, m] * B[m] over M in chunks of BLOCK_M.
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

    # Reduce over the M axis, then apply the addmv epilogue: alpha * (A @ B) + beta * Inp.
    acc = tl.sum(acc, axis=1)[:, None]
    Inp_ptrs = Inp + offset_n * stride_in
    inp = tl.load(Inp_ptrs, mask=n_mask, other=0.0).to(tl.float32)
    Out_ptrs = Out + offset_n * stride_outn
    out_block = acc * alpha + inp * beta
    tl.store(Out_ptrs, out_block, mask=n_mask)

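# Host-side wrappers: validate shapes, broadcast `self` to (N,), and launch
# one kernel program per BLOCK_N rows of `mat`.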

def addmv(self, mat, vec, *, beta=1, alpha=1):
    logger.debug("GEMS ADDMV")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    self = self.broadcast_to(out.shape)
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out

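# Usage sketch (illustrative only): the shapes and the device string below are
# example values, not part of this module; any device supported by this backend
# would do.
#   mat = torch.randn(64, 32, device="cuda")
#   vec = torch.randn(32, device="cuda")
#   bias = torch.randn(64, device="cuda")
#   res = addmv(bias, mat, vec, beta=2.0, alpha=0.5)
#   # expected to match torch.addmv(bias, mat, vec, beta=2.0, alpha=0.5)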

def addmv_out(self, mat, vec, *, beta=1, alpha=1, out=None):
    logger.debug("GEMS ADDMV OUT")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    if out is None:
        out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    else:
        assert out.shape == (N,), "Incompatible output shape"

    self = self.broadcast_to(out.shape)
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out
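# Usage sketch for the out= variant (illustrative only; shapes and device are
# example values):
#   buf = torch.empty(64, device="cuda")
#   addmv_out(bias, mat, vec, beta=1.0, alpha=1.0, out=buf)
#   # `buf` is filled with beta * bias + alpha * (mat @ vec) and also returned.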