Coverage for src/flag_gems/runtime/backend/_arm/ops/int_mm.py: 0%

125 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1""" 

2FlagGems ARM backend: Triton-CPU INT8 matmul for aten::_int_mm. 

3 

4Replaces the scalar fallback of aten::_int_mm on CPU with a Triton-CPU 

5SVE2 i8mm kernel on ARM64 (CIX P1 CD8180, SVE2 + i8mm). 

6 

7Interface: 

8 aten::_int_mm(Tensor self: int8, Tensor mat2: int8) -> Tensor: int32 

9 self : [M, K] int8 — already-quantised activation 

10 mat2 : [K, N] int8 — weight (may be supplied column-major, e.g. by torchao; copied to row-major [K, N] before the kernel runs) 

11 output: [M, N] int32 

12 

13Use cases covered: 

14 - torchao Int8DynamicActivationInt8WeightConfig 

15 - Any code that calls torch._int_mm / torch.ops.aten._int_mm on CPU 

16 

17Routing (same M-branch + padding strategy as quantized_linear_dynamic.py): 

18 M==1 → NEON SDOT GEMV when available, otherwise BM=1, BK=4 (ConvertDotGeneric, LLVM unrolls K loop) 

19 M==2 → BM=2, BK=4 (2-row ConvertDotGeneric) 

20 M%64==0 → BM=64, BK=32 (SVE2 i8mm Dynamic ForOp, ~411 GOPS) 

21 M%8==0 → BM=8, BK=32 (SVE2 i8mm Dynamic ForOp, ~100-170 GOPS) 

22 otherwise → pad to M%8==0, BM=8, BK=32 (e.g. M=84→88) 

23 

24Unlike quantized_linear_dynamic, no quant/dequant fusion is needed: inputs 

25are already int8, output is int32. A row-major copy of a column-major weight is still cached (see _INT_MM_B_CACHE). 

26 

27Scalar baseline: 1.9 GOPS (OMP=8 has no effect). 

28Triton target: M=1 → 63 GOPS, M=64 → 411 GOPS, M=84→88 → 170 GOPS. 

29""" 

30 

31import logging 

32import os 

33 

34import torch 

35import triton 

36import triton.language as tl 

37from triton.language.extra.cpu.tle_ops import sdot_gemv as _tle_sdot_gemv 

38from triton.language.extra.cpu.tle_ops import ( 

39 sdot_gemv_fused_bf16 as _tle_sdot_gemv_fused_bf16, 

40) 

41from triton.language.extra.cpu.tle_ops import ( 

42 sdot_pack_weights as _tle_sdot_pack_weights, 

43) 

44 

45logger = logging.getLogger(__name__) 

46 

47 

48# --------------------------------------------------------------------------- 

49# Triton kernel: int8 @ int8 → int32 (row-major weights, BK-loop) 

50# Reuses same pattern as _i8mm_kernel in quantized_linear_dynamic. 

51# --------------------------------------------------------------------------- 

52 

53 

@triton.jit
def _int8mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N int32 tile of C = A[M,K] @ B[K,N].

    Program ids 0 and 1 pick the tile along M and N.  Every load and the
    final store are unmasked, so the launcher has to choose a grid and
    block sizes that tile M, N and K exactly.
    """
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    # These address components do not depend on the K step.
    a_row_ptrs = a_ptr + rows[:, None] * stride_am
    b_col_ptrs = b_ptr + cols[None, :] * stride_bn
    c_tile_ptrs = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    num_k_steps = tl.cdiv(K, BLOCK_K)
    for step in range(0, num_k_steps):
        ks = step * BLOCK_K + tl.arange(0, BLOCK_K)
        a_tile = tl.load(a_row_ptrs + ks[None, :] * stride_ak)
        b_tile = tl.load(b_col_ptrs + ks[:, None] * stride_bk)
        acc += tl.dot(a_tile, b_tile)

    tl.store(c_tile_ptrs, acc)

89 

90 

# ---------------------------------------------------------------------------
# Weight cache: torchao int8dq provides col-major weights that need
# .contiguous() to become row-major for the Triton kernel.  Without caching,
# this copy (3-11 ms per call) dominates every token.  Lookups are keyed on
# the weight's data_ptr() (see _triton_int_mm), so each weight is made
# contiguous only once (first call per layer).
# ---------------------------------------------------------------------------
_INT_MM_B_CACHE: dict = {}

# ---------------------------------------------------------------------------
# NEON SDOT for M=1 INT8 GEMV via TLE @triton.jit ops (create_cpu_sdot_*).
# Pre-packed weights in SDOT-friendly format: B_packed[K//4, N//4, 4, 4]
# where B_packed[kb, nb, ni, ki] = B_original[kb*4+ki, nb*4+ni].
# Each TLE op is coarse (whole pack / whole GEMV = one kernel launch), no ctypes.
# ---------------------------------------------------------------------------
_SDOT_WEIGHT_CACHE: dict = {}  # (data_ptr, K, N) -> (B_packed, b_ref)
# Tri-state flag shared by the SDOT launchers in this module:
# None = not yet tried, True = TLE sdot path works, False = fall back to Triton.
_SDOT_TLE_OK = None

108 

109 

# Thin @triton.jit entry point around the TLE `sdot_pack_weights` op.
# Launched with a (1,) grid by _get_sdot_packed_weight; K and N are
# compile-time constants (tl.constexpr).
@triton.jit
def _sdot_pack_kernel(b_ptr, packed_ptr, K: tl.constexpr, N: tl.constexpr):
    _tle_sdot_pack_weights(b_ptr, packed_ptr, K, N)

113 

114 

# Thin @triton.jit entry point around the TLE `sdot_gemv` op.
# Launched with a (1,) grid by _launch_sdot_m1, which passes the int8
# activation, the pre-packed weight and an int32 output buffer.
@triton.jit
def _sdot_gemv_kernel(a_ptr, packed_ptr, c_ptr, K: tl.constexpr, N: tl.constexpr):
    _tle_sdot_gemv(a_ptr, packed_ptr, c_ptr, K, N)

118 

119 

# Thin @triton.jit entry point around the TLE `sdot_gemv_fused_bf16` op.
# Launched with a (1,) grid by launch_sdot_fused_bf16, which passes the
# bfloat16 activation, the pre-packed weight, the weight scale and a
# bfloat16 output buffer.
@triton.jit
def _sdot_gemv_fused_bf16_kernel(
    x_ptr, packed_ptr, ws_ptr, out_ptr, K: tl.constexpr, N: tl.constexpr
):
    _tle_sdot_gemv_fused_bf16(x_ptr, packed_ptr, ws_ptr, out_ptr, K, N)

125 

126 

127def _sdot_enabled(): 

128 return os.getenv("FLAGGEMS_ARM_SDOT", "1").lower() in ("1", "true", "on") 

129 

130 

def _get_sdot_packed_weight(b_rowmajor, K, N):
    """Return the SDOT-packed form of ``b_rowmajor``, packing it on first use.

    Results are memoised in ``_SDOT_WEIGHT_CACHE`` under
    ``(data_ptr, K, N)``.  Each entry also stores the source tensor so its
    memory stays allocated; otherwise the address could be handed to a
    different tensor and produce a stale cache hit.
    """
    cache_key = (b_rowmajor.data_ptr(), K, N)
    entry = _SDOT_WEIGHT_CACHE.get(cache_key)
    if entry is not None:
        return entry[0]

    # Destination buffer in the [K//4, N//4, 4, 4] packed layout.
    packed = torch.empty((K // 4, N // 4, 4, 4), dtype=torch.int8)
    _sdot_pack_kernel[(1,)](b_rowmajor, packed, K=K, N=N)
    _SDOT_WEIGHT_CACHE[cache_key] = (packed, b_rowmajor)  # keep source alive
    return packed

144 

145 

def launch_sdot_fused_bf16(x_bf16, b_rowmajor, w_scale, K, N):
    """Run the TLE ``sdot_gemv_fused_bf16`` op on a single activation vector.

    The op is described by its author as a fused BF16→INT8 quantisation,
    SDOT GEMV and dequantisation back to BF16.

    Args:
        x_bf16: expected to be a contiguous 1-D [K] bfloat16 activation.
        b_rowmajor: expected to be a row-major [K, N] int8 weight; it is
            pre-packed and cached via ``_get_sdot_packed_weight``.
        w_scale: expected to be a [N] float32 per-channel weight scale.
        K, N: problem dimensions; both must be multiples of 4.

    Returns:
        A [N] bfloat16 tensor, or None when the path is disabled
        (environment switch or an earlier failure), when K or N is not a
        multiple of 4, or when the launch raises.
    """
    global _SDOT_TLE_OK
    if _SDOT_TLE_OK is False or not _sdot_enabled():
        return None
    if K % 4 != 0 or N % 4 != 0:
        return None
    try:
        packed = _get_sdot_packed_weight(b_rowmajor, K, N)
        out = torch.empty(N, dtype=torch.bfloat16)
        _sdot_gemv_fused_bf16_kernel[(1,)](x_bf16, packed, w_scale, out, K=K, N=N)
    except Exception:
        # Best-effort fast path: disable it for the rest of the process, but
        # leave a trace so the performance regression can be diagnosed.
        _SDOT_TLE_OK = False
        logger.warning(
            "FlagGems ARM: fused BF16 SDOT GEMV failed (K=%d, N=%d); "
            "disabling the SDOT path",
            K,
            N,
            exc_info=True,
        )
        return None
    _SDOT_TLE_OK = True
    return out

172 

173 

def _launch_sdot_m1(a, b_rowmajor, K, N):
    """Run the M=1 int8 GEMV through the TLE ``sdot_gemv`` op.

    Args:
        a: int8 activation holding the single row of K values.
        b_rowmajor: row-major [K, N] int8 weight; it is pre-packed and
            cached via ``_get_sdot_packed_weight``.
        K, N: problem dimensions; both must be multiples of 4.

    Returns:
        A [1, N] int32 tensor, or None when the path is disabled
        (environment switch or an earlier failure), when K or N is not a
        multiple of 4, or when the launch raises.
    """
    global _SDOT_TLE_OK
    if _SDOT_TLE_OK is False or not _sdot_enabled():
        return None
    if K % 4 != 0 or N % 4 != 0:
        return None
    try:
        packed = _get_sdot_packed_weight(b_rowmajor, K, N)
        out = torch.empty(1, N, dtype=torch.int32)
        _sdot_gemv_kernel[(1,)](a, packed, out, K=K, N=N)
    except Exception:
        # Best-effort fast path: disable it for the rest of the process, but
        # leave a trace so the performance regression can be diagnosed.
        _SDOT_TLE_OK = False
        logger.warning(
            "FlagGems ARM: SDOT M=1 GEMV failed (K=%d, N=%d); "
            "disabling the SDOT path",
            K,
            N,
            exc_info=True,
        )
        return None
    _SDOT_TLE_OK = True
    return out

191 

192 

def _cached_rowmajor_weight(mat2):
    """Return a row-major (contiguous) version of ``mat2``, copying at most once.

    An already-contiguous tensor is returned as-is.  Otherwise the copy is
    cached under (data_ptr, shape, strides) so that views sharing one
    address (for example ``W`` and ``W.t()``) cannot collide.  The entry
    also keeps the source tensor, which stops its address being reused by
    an unrelated tensor and producing a stale hit.
    """
    if mat2.is_contiguous():
        return mat2
    key = (mat2.data_ptr(), tuple(mat2.shape), mat2.stride())
    entry = _INT_MM_B_CACHE.get(key)
    if entry is None:
        entry = (mat2.contiguous(), mat2)  # (row-major copy, keep-alive ref)
        _INT_MM_B_CACHE[key] = entry
    return entry[0]


def _launch_int8mm(a, b, M, N, K, BM, BN, BK):
    """Allocate an [M, N] int32 output and run ``_int8mm_kernel`` over it.

    The kernel's loads and stores are unmasked, so M, N and K must be exact
    multiples of BM, BN and BK respectively.
    """
    out = torch.empty(M, N, dtype=torch.int32)
    _int8mm_kernel[(M // BM, N // BN)](
        a,
        b,
        out,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        out.stride(0),
        out.stride(1),
        BLOCK_M=BM,
        BLOCK_N=BN,
        BLOCK_K=BK,
    )
    return out


def _triton_int_mm(self: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    """Triton-CPU replacement for aten::_int_mm on ARM64.

    Args:
        self: [M, K] int8 activation (made contiguous on every call).
        mat2: [K, N] int8 weight (a row-major copy is cached when needed).

    Returns:
        [M, N] int32 product.

    Raises:
        AssertionError: if either input is not int8 or the inner
            dimensions disagree.
        ValueError: if either input is not 2-D (from shape unpacking).

    Shapes the unmasked kernel cannot tile exactly (N not a multiple of 64,
    K not a multiple of 4, or any zero-sized dimension) are computed with a
    plain int32 matmul instead.
    """
    # Raised explicitly rather than via `assert` so the checks survive
    # `python -O`; the exception type is unchanged for existing callers.
    if self.dtype != torch.int8 or mat2.dtype != torch.int8:
        raise AssertionError(
            f"_int_mm expects int8 inputs, got {self.dtype}, {mat2.dtype}"
        )
    M, K = self.shape
    K2, N = mat2.shape
    if K != K2:
        raise AssertionError(f"_int_mm shape mismatch: [{M},{K}] @ [{K2},{N}]")

    BN = 64
    BK_prefill = 32
    BK_small = 4  # smallest K step used by any kernel configuration below

    # Fallback for shapes the kernel cannot cover without out-of-bounds
    # accesses or an empty launch grid.  No weight copy is cached for these.
    if min(M, N, K) == 0 or N % BN != 0 or K % BK_small != 0:
        logger.debug(
            "FlagGems _int_mm: [%d,%d] @ [%d,%d] not tileable "
            "(need N %% 64 == 0, K %% 4 == 0, non-empty), using int32 fallback",
            M,
            K,
            K,
            N,
        )
        return self.to(torch.int32) @ mat2.to(torch.int32)

    # Activation changes every call, so it is never cached.
    a = self.contiguous()
    # Weight: the first call per layer pays the copy, later calls reuse it.
    b = _cached_rowmajor_weight(mat2)

    # Decode M=1: try the NEON SDOT GEMV on pre-packed weights first and
    # fall back to the generic kernel (BM=1, BK=4) when it is unavailable.
    if M == 1:
        sdot_result = _launch_sdot_m1(a, b, K, N)
        if sdot_result is not None:
            return sdot_result
        return _launch_int8mm(a, b, M, N, K, BM=1, BN=BN, BK=BK_small)

    if M == 2:
        return _launch_int8mm(a, b, M, N, K, BM=2, BN=BN, BK=BK_small)

    # Prefill path (M >= 3): prefer BK=32, and zero-pad M up to a multiple
    # of the row block so the unmasked kernel never reads past the input.
    BK = BK_prefill if K % BK_prefill == 0 else BK_small
    BM = 64 if M % 64 == 0 else 8
    M_kernel = ((M + BM - 1) // BM) * BM

    if M_kernel == M:
        a_kernel = a
    else:
        a_kernel = torch.zeros(M_kernel, K, dtype=torch.int8)
        a_kernel[:M].copy_(a)

    out_kernel = _launch_int8mm(a_kernel, b, M_kernel, N, K, BM=BM, BN=BN, BK=BK)

    # Drop the padding rows again when any were added.
    return out_kernel[:M] if M_kernel != M else out_kernel

322 

323 

324# --------------------------------------------------------------------------- 

325# Registration 

326# --------------------------------------------------------------------------- 

327 

_int_mm_lib = None  # keep reference alive to prevent GC


def register():
    """Register the Triton implementation for aten::_int_mm on CPU.

    Idempotent: once registration has succeeded, later calls return
    immediately.  The module-level library handle is only stored after
    ``impl`` succeeds, so a failed attempt is not mistaken for a completed
    registration and a later call can try again.  Failures are logged as
    warnings and never raised.
    """
    global _int_mm_lib
    if _int_mm_lib is not None:
        return

    try:
        lib = torch.library.Library("aten", "IMPL")
        lib.impl(
            "_int_mm",
            _triton_int_mm,
            "CPU",
            allow_override=True,
        )
    except Exception as e:
        logger.warning("FlagGems ARM: failed to register aten::_int_mm override: %s", e)
        return

    _int_mm_lib = lib
    logger.debug("FlagGems ARM: registered Triton-CPU i8mm for aten::_int_mm")