Coverage for src/flag_gems/ops/w8a8_block_fp8_matmul.py: 49%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import functools 

2import json 

3import logging 

4import os 

5from typing import Any, Dict, List, Optional 

6 

7import torch 

8import triton 

9import triton.language as tl 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@triton.jit
def w8a8_block_fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Block-scaled matmul: each program computes one (BLOCK_SIZE_M,
    BLOCK_SIZE_N) tile of C from A (M x K) and B (accessed as K x N via
    its strides), applying per-row scales As and per-block scales Bs.

    As provides one scale per row of A for each group_k-wide slice of K
    (indexed via stride_As_m / stride_As_k); Bs provides one scale per
    (group_k x group_n) block of B (indexed via stride_Bs_k /
    stride_Bs_n).  Accumulation is in fp32; scales are multiplied in
    after each partial dot product.

    NOTE(review): A and B are presumably fp8 tensors (per the w8a8 name),
    but the kernel itself only requires dtypes accepted by tl.dot —
    confirm at the call site.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped program ordering: consecutive pids walk GROUP_SIZE_M rows of
    # output tiles before advancing along N, to improve cache reuse of B.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/col offsets wrap with % so out-of-range lanes still read valid
    # memory; the final store is masked, so duplicated values are dropped.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Base pointers for the per-row (A) and per-column-block (B) scales.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the K tail on the last iteration; padded lanes contribute 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # One scale index per K block: scales are constant within a
        # group_k slice of K.
        # NOTE(review): this assumes a BLOCK_SIZE_K tile never straddles a
        # group_k boundary (i.e. group_k % BLOCK_SIZE_K == 0 or similar) —
        # confirm against the launcher's config.
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        # Dequantize the partial product by row scale and column-block scale.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Downcast the fp32 accumulator to the output dtype; anything other
    # than bf16/fp16 is stored as fp32.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Unwrapped offsets here: the store mask discards out-of-range lanes.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

87 

88 

@functools.lru_cache
def get_w8a8_block_fp8_configs(
    N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
    """Look up tuned kernel configs for this (N, K, block shape) and GPU.

    Returns a mapping from M (the nearest tuned batch size) to a Triton
    launch-config dict, or None when no tuned JSON file exists for the
    current device, in which case the caller falls back to defaults.
    """
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    json_file_name = (
        f"N={N},K={K},device_name={device_name},"
        f"dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json"
    )
    config_file_path = os.path.join(
        os.path.dirname(__file__), "..", "utils", "configs", json_file_name
    )

    # Guard clause: no tuned file for this device/shape combination.
    if not os.path.exists(config_file_path):
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s",
            config_file_path,
        )
        return None

    logger.info(
        "Using configuration from %s for W8A8 Block FP8 kernel.",
        config_file_path,
    )
    with open(config_file_path) as f:
        raw = json.load(f)
    # JSON object keys are strings; convert them back to integer M values.
    return {int(m): cfg for m, cfg in raw.items()}

116 

117 

def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Launch the block-scaled matmul kernel: C = dequant(A) @ dequant(B).T.

    A is a (..., K) contiguous quantized activation tensor with per-row,
    per-K-group scales As; B is an (N, K) quantized weight tensor with
    per-(block_n, block_k)-block scales Bs.  `block_size` is
    [block_n, block_k].  Returns a tensor of shape A.shape[:-1] + (N,)
    with dtype `output_dtype`.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Activations and their scales must agree; contiguity is required
    # because the kernel addresses A through its last two strides only.
    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    # Weight scales must tile B exactly, one scale per quantization block.
    assert B.ndim == 2 and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    out_shape = A.shape[:-1] + (N,)
    C = A.new_empty(out_shape, dtype=output_dtype)

    # Prefer a tuned config for the batch size nearest to M; otherwise use
    # a conservative default whose N/K tiles match the scale block shape.
    configs = get_w8a8_block_fp8_configs(N, K, block_n, block_k)
    if configs:
        nearest_m = min(configs, key=lambda m: abs(m - M))
        config = configs[nearest_m]
    else:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_size[0],
            "BLOCK_SIZE_K": block_size[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 2,
        }

    def grid(meta):
        # 1-D launch: one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) tile.
        num_m = triton.cdiv(M, meta["BLOCK_SIZE_M"])
        num_n = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (num_m * num_n,)

    # B is (N, K) in memory but the kernel consumes it as (K, N), hence
    # the swapped stride order; same for Bs.
    w8a8_block_fp8_matmul_kernel[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C