Coverage for src/flag_gems/runtime/backend/_metax/ops/addmm.py: 0%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import broadcastable_to, libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module-level logger, namespaced under the flag_gems logging hierarchy.
logger = logging.getLogger("flag_gems." + __name__)

14 

15 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(
    {
        # Number of output tiles is ceil(M*N / (BLOCK_SIZE_M*BLOCK_SIZE_N)).
        # UPGRADE is True when that count needs more than 32 bits to
        # represent, i.e. the launch grid is large enough that a 32-bit
        # program id could overflow.
        "UPGRADE": lambda args: math.ceil(
            (args["M"] * args["N"]) / (args["BLOCK_SIZE_M"] * args["BLOCK_SIZE_N"])
        ).bit_length()
        > 32,
    }
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,  # *A: (M, K) matrix operand
    b_ptr,  # *B: (K, N) matrix operand
    i_ptr,  # *I: (M, N) bias ("input" in torch.addmm terms)
    c_ptr,  # *C: (M, N) output
    alpha,  # scalar multiplier applied to A @ B
    beta,  # scalar multiplier applied to the bias
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    UPGRADE: tl.constexpr,
):
    """Compute C = alpha * (A @ B) + beta * I, one (BLOCK_SIZE_M, BLOCK_SIZE_N)
    output tile per program on a 2D grid (axis 0 tiles M, axis 1 tiles N).
    """
    if UPGRADE:
        # tle.program_id is the flag_gems extension variant, selected for
        # very large grids — presumably it yields a wider index type than
        # tl.program_id; TODO confirm against triton_lang_extension.
        pid_m = tle.program_id(0)
        pid_n = tle.program_id(1)
    else:
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

    # Row offsets of this tile in A/C, column offsets in B/C, and the
    # K-offsets for the first K-block; build block pointers for A and B.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate in fp32 regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Masks zero out out-of-bounds rows/columns and the K-tail of the
        # last iteration (K - k*BLOCK_SIZE_K remaining elements).
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        # allow_tf32=False keeps the dot product in full fp32 precision.
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both block pointers one K-block along the K dimension.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Output-tile offsets, pointers, and the bounds mask shared by the
    # bias load and the final store.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    # Scale, add the bias, cast back down to the bias dtype, and store.
    accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)

92 

93 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """MetaX-backend addmm: return beta * bias + alpha * (mat1 @ mat2).

    Mirrors torch.addmm's signature; ``bias`` may be any shape broadcastable
    to ``(mat1.shape[0], mat2.shape[1])``. Raises AssertionError on
    incompatible shapes.
    """
    logger.debug("METAX GEMS ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    N = mat2.shape[1]

    logger.debug(
        "METAX GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.stride(0) == 1,
    )

    # The kernel is written for dense layouts: make both operands
    # contiguous and materialize the bias at the full output shape.
    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    bias = bias.broadcast_to(out.shape).contiguous()

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *mat1.stride(),
            *mat2.stride(),
            *bias.stride(),
            *out.stride(),
        )
    return out