Coverage for src/flag_gems/runtime/backend/_arm/ops/int_mm.py: 0%
125 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
"""
FlagGems ARM backend: Triton-CPU INT8 matmul for aten::_int_mm.

Replaces the scalar fallback of aten::_int_mm on CPU with a Triton-CPU
SVE2 i8mm kernel on ARM64 (CIX P1 CD8180, SVE2 + i8mm), plus a NEON SDOT
GEMV path (TLE ops) for M == 1.

Interface:
    aten::_int_mm(Tensor self: int8, Tensor mat2: int8) -> Tensor: int32
    self  : [M, K] int8 — already-quantised activation
    mat2  : [K, N] int8 — weight; torchao supplies it with column-major
            storage (the same bytes as a row-major [N, K] tensor), so a
            row-major [K, N] copy is made and cached for the kernels
    output: [M, N] int32

Use cases covered:
    - torchao Int8DynamicActivationInt8WeightConfig
    - Any code that calls torch._int_mm / torch.ops.aten._int_mm on CPU

Routing (same M-branch + padding strategy as quantized_linear_dynamic.py):
    N%64!=0   → exact int32 torch matmul fallback
    M==1      → NEON SDOT GEMV on pre-packed weights when available,
                otherwise BM=1, BK=4 (ConvertDotGeneric, LLVM unrolls K loop)
    M==2      → BM=2, BK=4 (2-row ConvertDotGeneric)
    M%64==0   → BM=64, BK=32 (SVE2 i8mm Dynamic ForOp, ~411 GOPS)
    M%8==0    → BM=8,  BK=32 (SVE2 i8mm Dynamic ForOp, ~100-170 GOPS)
    otherwise → pad to M%8==0, BM=8, BK=32 (e.g. M=84→88)
    For M >= 3, BK=32 is only used when K%32==0; otherwise BK=4.

Unlike quantized_linear_dynamic, no quant/dequant fusion is needed on the
_int_mm path: inputs are already int8 and the output is int32. A row-major
weight cache (_INT_MM_B_CACHE) is still kept so column-major weights are
not re-copied on every call. launch_sdot_fused_bf16 (not called in this
file) separately exposes a fused BF16 → INT8 → SDOT GEMV → BF16 path.

Scalar baseline: 1.9 GOPS (OMP=8 has no effect).
Triton target: M=1 → 63 GOPS, M=64 → 411 GOPS, M=84→88 → 170 GOPS.
"""
31import logging
32import os
34import torch
35import triton
36import triton.language as tl
37from triton.language.extra.cpu.tle_ops import sdot_gemv as _tle_sdot_gemv
38from triton.language.extra.cpu.tle_ops import (
39 sdot_gemv_fused_bf16 as _tle_sdot_gemv_fused_bf16,
40)
41from triton.language.extra.cpu.tle_ops import (
42 sdot_pack_weights as _tle_sdot_pack_weights,
43)
# Module-level logger for routing, fallback and registration diagnostics.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Triton kernel: int8 @ int8 → int32 (row-major weights, BK-loop)
# Reuses same pattern as _i8mm_kernel in quantized_linear_dynamic.
# Loads and stores are unmasked: launch only with M % BLOCK_M == 0,
# N % BLOCK_N == 0 and K % BLOCK_K == 0.
# ---------------------------------------------------------------------------


@triton.jit
def _int8mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """int8 GEMM: A[M,K] int8 @ B[K,N] int8 → C[M,N] int32.

    Each program instance computes one BLOCK_M x BLOCK_N tile of C:
    program_id(0) selects the tile row (M direction) and program_id(1)
    the tile column (N direction).

    No bounds masks are applied to any tl.load / tl.store. The launch must
    therefore satisfy M % BLOCK_M == 0, N % BLOCK_N == 0 and
    K % BLOCK_K == 0; otherwise the last tiles touch memory outside A, B
    or C. M and N are not read by the body; only K (loop trip count) is.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Row / column indices of the output tile owned by this program.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Accumulate in int32, the dtype of C.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    # Walk K in BLOCK_K-wide slabs. tl.cdiv rounds up, which is only safe
    # because the caller guarantees K is a multiple of BLOCK_K (no K mask).
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        offs_k = k * BLOCK_K + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
        acc += tl.dot(a, b)

    tl.store(
        c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
        acc,
    )
# ---------------------------------------------------------------------------
# Weight cache: torchao int8dq provides col-major weights that need
# .contiguous() to become row-major for the Triton kernel. Without caching,
# this copy (3-11ms per call) dominates every token. Entries are keyed on
# the weight's data_ptr() (see its use in _triton_int_mm for the exact key),
# so each weight is made contiguous only once (first call per layer).
# NOTE(review): nothing in this file evicts entries, so cached weights stay
# alive for the life of the process.
# ---------------------------------------------------------------------------
_INT_MM_B_CACHE: dict = {}

# ---------------------------------------------------------------------------
# NEON SDOT for M=1 INT8 GEMV via TLE @triton.jit ops (create_cpu_sdot_*).
# Pre-packed weights in SDOT-friendly format: B_packed[K//4, N//4, 4, 4]
# where B_packed[kb, nb, ni, ki] = B_original[kb*4+ki, nb*4+ni].
# Each TLE op is coarse (whole pack / whole GEMV = one kernel launch), no ctypes.
# ---------------------------------------------------------------------------
_SDOT_WEIGHT_CACHE: dict = {}  # (data_ptr, K, N) -> (B_packed, b_ref); b_ref pins data_ptr
# Tri-state flag shared by launch_sdot_fused_bf16 and _launch_sdot_m1:
# None = not yet tried, True = TLE sdot path works, False = a launch raised,
# after which both launchers return None (nothing in this file resets it).
_SDOT_TLE_OK = None
@triton.jit
def _sdot_pack_kernel(b_ptr, packed_ptr, K: tl.constexpr, N: tl.constexpr):
    """Pack row-major int8 B[K, N] into the SDOT layout with one TLE op.

    Output layout: B_packed[K//4, N//4, 4, 4]. Launched on a single-program
    grid (1,) by _get_sdot_packed_weight. K and N are constexpr, so each
    distinct (K, N) pair compiles its own specialization.
    """
    _tle_sdot_pack_weights(b_ptr, packed_ptr, K, N)
@triton.jit
def _sdot_gemv_kernel(a_ptr, packed_ptr, c_ptr, K: tl.constexpr, N: tl.constexpr):
    """M=1 int8 GEMV on SDOT-packed weights with one TLE op.

    _launch_sdot_m1 passes a [1, K] int8 activation, the packed weight from
    _get_sdot_packed_weight and a [1, N] int32 output, on grid (1,).
    No strides are passed, so all buffers are presumably read as dense.
    """
    _tle_sdot_gemv(a_ptr, packed_ptr, c_ptr, K, N)
@triton.jit
def _sdot_gemv_fused_bf16_kernel(
    x_ptr, packed_ptr, ws_ptr, out_ptr, K: tl.constexpr, N: tl.constexpr
):
    """Fused BF16→INT8 quant + SDOT GEMV + dequant→BF16 with one TLE op.

    Used by launch_sdot_fused_bf16 with x [K] bfloat16, SDOT-packed weights,
    ws [N] float32 per-channel weight scales and out [N] bfloat16, on grid
    (1,). No strides are passed, so all buffers are read as dense.
    """
    _tle_sdot_gemv_fused_bf16(x_ptr, packed_ptr, ws_ptr, out_ptr, K, N)
def _sdot_enabled():
    """Report whether the NEON SDOT path is enabled by the environment.

    Reads FLAGGEMS_ARM_SDOT (default "1"); only "1", "true" and "on",
    compared case-insensitively, count as enabled.
    """
    flag = os.environ.get("FLAGGEMS_ARM_SDOT", "1")
    return flag.lower() in {"1", "true", "on"}
def _get_sdot_packed_weight(b_rowmajor, K, N):
    """Return the SDOT-packed form of ``b_rowmajor``, packing on first use.

    Results live in _SDOT_WEIGHT_CACHE under (data_ptr, K, N). Each entry
    also keeps ``b_rowmajor`` itself alive, so its address cannot be handed
    to another tensor and turn into a stale cache hit.
    """
    cache_key = (b_rowmajor.data_ptr(), K, N)
    hit = _SDOT_WEIGHT_CACHE.get(cache_key)
    if hit is not None:
        return hit[0]

    # Layout: b_packed[kb, nb, ni, ki] = B[kb*4 + ki, nb*4 + ni].
    b_packed = torch.empty((K // 4, N // 4, 4, 4), dtype=torch.int8)
    _sdot_pack_kernel[(1,)](b_rowmajor, b_packed, K=K, N=N)
    _SDOT_WEIGHT_CACHE[cache_key] = (b_packed, b_rowmajor)  # second item pins b_rowmajor
    return b_packed
def launch_sdot_fused_bf16(x_bf16, b_rowmajor, w_scale, K, N):
    """Fused BF16→INT8 quant + SDOT GEMV + dequant→BF16 via TLE NEON (neon.py).

    Args:
        x_bf16: [K] bfloat16 activation (1D; made contiguous if it is not)
        b_rowmajor: [K, N] int8 weight (row-major, will be pre-packed and cached)
        w_scale: [N] float32 per-channel weight scale (made contiguous if it is not)
        K, N: dimensions; both must be multiples of 4

    Returns:
        [N] bfloat16 output, or None if not applicable. That covers SDOT
        disabled via FLAGGEMS_ARM_SDOT, a previous TLE SDOT launch that
        failed, or K/N not multiples of 4.

    A launch failure is logged and sets _SDOT_TLE_OK = False. That flag is
    shared, so the M=1 SDOT path of _int_mm is disabled as well.
    """
    global _SDOT_TLE_OK
    if _SDOT_TLE_OK is False or not _sdot_enabled():
        return None
    if K % 4 != 0 or N % 4 != 0:
        return None
    try:
        packed = _get_sdot_packed_weight(b_rowmajor, K, N)
        out = torch.empty(N, dtype=torch.bfloat16)
        # The kernel gets raw pointers and no strides, so x and w_scale must
        # be dense; .contiguous() is a no-op for already-contiguous inputs.
        _sdot_gemv_fused_bf16_kernel[(1,)](
            x_bf16.contiguous(), packed, w_scale.contiguous(), out, K=K, N=N
        )
    except Exception as exc:
        # Best-effort fast path: disable it, but record why instead of
        # silently swallowing the failure.
        _SDOT_TLE_OK = False
        logger.warning(
            "FlagGems ARM: TLE SDOT fused BF16 GEMV failed (%s); disabling SDOT path",
            exc,
        )
        logger.debug("TLE SDOT fused BF16 GEMV failure traceback", exc_info=True)
        return None
    _SDOT_TLE_OK = True
    return out
def _launch_sdot_m1(a, b_rowmajor, K, N):
    """Launch NEON SDOT M=1 GEMV via TLE NEON (neon.py).

    Args:
        a: [1, K] int8 activation (_triton_int_mm passes ``self.contiguous()``).
        b_rowmajor: [K, N] int8 row-major weight; pre-packed and cached on
            first use by _get_sdot_packed_weight.
        K, N: dimensions; both must be multiples of 4.

    Returns:
        [1, N] int32, or None if not applicable. That covers SDOT disabled
        via FLAGGEMS_ARM_SDOT, a previous TLE SDOT launch that failed, or
        K/N not multiples of 4.

    A launch failure is logged and sets _SDOT_TLE_OK = False. That flag is
    shared, so launch_sdot_fused_bf16 is disabled as well.
    """
    global _SDOT_TLE_OK
    if _SDOT_TLE_OK is False or not _sdot_enabled():
        return None
    if K % 4 != 0 or N % 4 != 0:
        return None
    try:
        packed = _get_sdot_packed_weight(b_rowmajor, K, N)
        out = torch.empty(1, N, dtype=torch.int32)
        _sdot_gemv_kernel[(1,)](a, packed, out, K=K, N=N)
    except Exception as exc:
        # Best-effort fast path: the caller falls back to Triton. Log why
        # once instead of disabling SDOT silently.
        _SDOT_TLE_OK = False
        logger.warning(
            "FlagGems ARM: TLE SDOT M=1 GEMV failed (%s); falling back to Triton",
            exc,
        )
        logger.debug("TLE SDOT M=1 GEMV failure traceback", exc_info=True)
        return None
    _SDOT_TLE_OK = True
    return out
def _get_rowmajor_weight(mat2: torch.Tensor) -> torch.Tensor:
    """Return a row-major (contiguous) view of ``mat2``, caching copies.

    Contiguous weights are returned unchanged: ``.contiguous()`` would be a
    no-op, and caching them would only pin transient tensors in memory.

    Non-contiguous weights (e.g. torchao's column-major int8 weights) are
    copied once and cached in _INT_MM_B_CACHE. The key includes shape and
    strides as well as ``data_ptr()``, because different views of one
    storage (``w`` vs ``w.t()``, column slices, ...) share a data_ptr.
    Each entry also holds the original tensor, which keeps its storage
    alive. Otherwise the address could be reused by an unrelated tensor
    and return a stale copy.

    NOTE: in-place modification of a cached weight is not detected.
    """
    if mat2.is_contiguous():
        return mat2
    key = (mat2.data_ptr(), tuple(mat2.shape), mat2.stride())
    entry = _INT_MM_B_CACHE.get(key)
    if entry is None:
        entry = (mat2.contiguous(), mat2)  # (row-major copy, ref to original)
        _INT_MM_B_CACHE[key] = entry
    return entry[0]


def _launch_int8mm(a, b, out, M, N, K, BM, BN, BK):
    """Launch _int8mm_kernel over an exact (M // BM, N // BN) tile grid.

    The kernel has no bounds masking, so the caller must guarantee
    M % BM == 0, N % BN == 0 and K % BK == 0. Returns ``out``.
    """
    _int8mm_kernel[(M // BM, N // BN)](
        a,
        b,
        out,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        out.stride(0),
        out.stride(1),
        BLOCK_M=BM,
        BLOCK_N=BN,
        BLOCK_K=BK,
    )
    return out


def _triton_int_mm(self: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    """
    Triton-CPU replacement for aten::_int_mm on ARM64.

    self : [M, K] int8 — activation (changes every token, not cached)
    mat2 : [K, N] int8 — weight (fixed after quantization; row-major copies
           of non-contiguous weights are cached, see _get_rowmajor_weight)
    Returns [M, N] int32

    Routing:
      M, N or K == 0            → zeros (nothing to launch)
      N % 64 != 0 or K % 4 != 0 → exact int32 torch matmul (the Triton
                                  kernel has no bounds masking)
      M == 1 → NEON SDOT GEMV, else Triton BM=1, BK=4
      M == 2 → Triton BM=2, BK=4
      M >= 3 → Triton BM=64 (M % 64 == 0) or BM=8 (M zero-padded to a
               multiple of 8); BK=32 if K % 32 == 0 else 4
    """
    assert (
        self.dtype == torch.int8 and mat2.dtype == torch.int8
    ), f"_int_mm expects int8 inputs, got {self.dtype}, {mat2.dtype}"
    M, K = self.shape
    K2, N = mat2.shape
    assert K == K2, f"_int_mm shape mismatch: [{M},{K}] @ [{K2},{N}]"

    # Degenerate shapes: avoid zero-sized grids / empty SDOT packs. An empty
    # reduction (K == 0) correctly yields an all-zero result.
    if M == 0 or N == 0 or K == 0:
        return torch.zeros(M, N, dtype=torch.int32)

    BN = 64
    BK_prefill = 32

    # The kernel's loads are unmasked: it needs N % BN == 0, and every BK
    # used below (4 or 32) must divide K, which K % 4 == 0 covers. Other
    # shapes are uncommon in practice and use an exact int32 matmul.
    if N % BN != 0 or K % 4 != 0:
        logger.debug(
            "FlagGems _int_mm: N=%d K=%d not aligned (N%%64, K%%4), using int32 fallback",
            N,
            K,
        )
        return self.to(torch.int32) @ mat2.to(torch.int32)

    # Activation: always contiguous (per-token, no cache).
    a = self.contiguous()
    # Weight: the first call per layer pays the row-major copy; later token
    # decodes are a dict lookup.
    b = _get_rowmajor_weight(mat2)

    # ------------------------------------------------------------------
    # Decode M=1: NEON SDOT GEMV on pre-packed weights (TLE ops).
    # Pre-packs B[K,N] → B_packed[K//4, N//4, 4, 4] SDOT lane format.
    # 2.5x faster than Triton SMLAL; falls back to it if SDOT is unavailable.
    # ------------------------------------------------------------------
    if M == 1:
        sdot_result = _launch_sdot_m1(a, b, K, N)
        if sdot_result is not None:
            return sdot_result
        # Fallback: Triton SMLAL (BM=1, BK=4).
        out = torch.empty(M, N, dtype=torch.int32)
        return _launch_int8mm(a, b, out, M, N, K, 1, BN, 4)

    if M == 2:
        out = torch.empty(M, N, dtype=torch.int32)
        return _launch_int8mm(a, b, out, M, N, K, 2, BN, 4)

    # ------------------------------------------------------------------
    # Prefill path (M >= 3): BK=32, target SVE2 i8mm Dynamic ForOp.
    # Pad M to next multiple of 8 to unlock Dynamic ForOp for all shapes.
    # ------------------------------------------------------------------
    BK = BK_prefill if K % BK_prefill == 0 else 4

    if M % 64 == 0:
        BM, M_kernel, a_kernel = 64, M, a
    elif M % 8 == 0:
        BM, M_kernel, a_kernel = 8, M, a
    else:
        # Zero-pad rows to the next multiple of 8; padded rows produce
        # output rows that are sliced away below.
        BM = 8
        M_kernel = ((M + 7) // 8) * 8
        a_kernel = torch.zeros(M_kernel, K, dtype=torch.int8)
        a_kernel[:M].copy_(a)

    out_kernel = torch.empty(M_kernel, N, dtype=torch.int32)
    _launch_int8mm(a_kernel, b, out_kernel, M_kernel, N, K, BM, BN, BK)
    return out_kernel[:M] if M_kernel != M else out_kernel
# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------

# Library handle for the aten::_int_mm CPU override. It is kept at module
# level so the registration is not dropped when the Library object is
# garbage-collected. It is only set once registration has succeeded.
_int_mm_lib = None


def register():
    """
    Register Triton implementation for aten::_int_mm on CPU.

    Idempotent: once registration has succeeded, further calls are no-ops.
    If registration fails, a warning is logged and ``_int_mm_lib`` stays
    None, so a later call retries instead of silently doing nothing.
    """
    global _int_mm_lib
    if _int_mm_lib is not None:
        return

    try:
        lib = torch.library.Library("aten", "IMPL")
        lib.impl(
            "_int_mm",
            _triton_int_mm,
            "CPU",
            allow_override=True,
        )
    except Exception as e:
        # Best-effort: keep the stock aten kernel if the override fails.
        logger.warning("FlagGems ARM: failed to register aten::_int_mm override: %s", e)
        return

    _int_mm_lib = lib
    logger.debug("FlagGems ARM: registered Triton-CPU i8mm for aten::_int_mm")