Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmm.py: 0%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable_to, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13 

14 

# One program computes a BLOCK_SIZE_M x BLOCK_SIZE_N tile of
# C = alpha * (A @ B) + beta * I, iterating over K in BLOCK_SIZE_K steps.
@libentry()
@triton.autotune(
    # NOTE(review): configs is empty; tuning configs are presumably produced
    # by the backend from the "addmm" template via generate_configs — confirm.
    configs=[],
    generate_configs="addmm",
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,  # A: (M, K) operand
    b_ptr,  # B: (K, N) operand
    i_ptr,  # bias ("input"), already broadcast to (M, N) by the host wrapper
    c_ptr,  # C: (M, N) output
    alpha,  # scale applied to A @ B
    beta,  # scale applied to the bias
    M,
    N,
    K,
    stride_am,  # element strides of A
    stride_ak,
    stride_bk,  # element strides of B
    stride_bn,
    stride_im,  # element strides of the bias
    stride_in,
    stride_cm,  # element strides of C
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # 2D launch grid: axis 0 tiles the M dimension, axis 1 tiles N.
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Pointers to the first K-slab of this program's A-row block / B-column block.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate in fp32 regardless of the input dtype for accuracy.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # NOTE(review): these loads are unmasked, so they read past the matrix
        # edges whenever M, N, or K is not a multiple of its block size —
        # presumably the generated configs guarantee divisibility; confirm.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both operands one K-slab.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # Bias load and final store ARE masked to the valid (M, N) region.
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    # addmm epilogue: C = alpha * (A @ B) + beta * bias, cast to the bias dtype.
    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)

71 

72 

def addmm(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    """Return ``alpha * (mat1 @ mat2) + beta * bias`` as a new tensor.

    ``mat1`` is (M, K), ``mat2`` is (K, N), and ``bias`` must be
    broadcastable to the (M, N) result.
    """
    logger.debug("GEMS ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    N = mat2.shape[1]

    # mat2 is passed with its own strides, so only mat1 is made contiguous.
    mat1 = mat1.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # Materialize the broadcast view so the kernel sees plain 2D strides.
    bias = bias.broadcast_to(out.shape)

    def grid(meta):
        # One program per output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1, mat2, bias, out,
            alpha, beta,
            M, N, K,
            mat1.stride(0), mat1.stride(1),
            mat2.stride(0), mat2.stride(1),
            bias.stride(0), bias.stride(1),
            out.stride(0), out.stride(1),
        )
    return out

112 

113 

def addmm_out(bias, mat1, mat2, *, beta=1.0, alpha=1.0, out=None):
    """Compute ``alpha * (mat1 @ mat2) + beta * bias`` into ``out``.

    Args:
        bias: tensor broadcastable to the (M, N) result.
        mat1: (M, K) tensor.
        mat2: (K, N) tensor.
        beta: scale applied to ``bias``.
        alpha: scale applied to ``mat1 @ mat2``.
        out: optional (M, N) destination tensor; allocated when ``None``.

    Returns:
        ``out``, holding the result.
    """
    logger.debug("GEMS ADDMM OUT")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"
        # A caller-supplied out must live on the compute device and match the
        # compute dtype; a mismatch would otherwise fail opaquely inside the
        # kernel launch (or silently corrupt the result).
        assert out.device == mat1.device, "Incompatible output device"
        assert out.dtype == mat1.dtype, "Incompatible output dtype"

    # mat2 and out are passed with their own strides; only mat1 is made
    # contiguous. Materialize the bias broadcast so the kernel sees 2D strides.
    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    # One program per output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out