Coverage for src/flag_gems/runtime/backend/_metax/ops/bmm.py: 0%
92 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger("flag_gems." + __name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.heuristics(
    {
        # The launch grid holds roughly batch * ceil(M/TILE_M) * ceil(N/TILE_N)
        # CTAs.  When that count does not fit in 32 bits the kernel must use
        # 64-bit program ids (tle.*) instead of tl.program_id.
        # Fix: divide by TILE_M * TILE_N — the original divided by
        # TILE_M * TILE_M, undercounting tiles whenever TILE_M != TILE_N.
        "UPGRADE": lambda args: math.ceil(
            (args["M"] * args["N"] * args["batch"]) / (args["TILE_M"] * args["TILE_N"])
        ).bit_length()
        > 32,
    }
)
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    batch,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    UPGRADE: tl.constexpr,
):
    """Batched matmul kernel: O[b] = A[b] @ B[b].

    A is (batch, M, K), B is (batch, K, N), O is (batch, M, N), all assumed
    contiguous row-major (the Python wrapper enforces this).  Each program
    computes one TILE_M x TILE_N tile of one batch's output; accumulation is
    in float32 with tf32 disabled.  UPGRADE selects 64-bit program-id helpers
    (tle.*) when the grid is too large for 32-bit ids.
    """
    # Advance base pointers to this batch's matrices.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    if UPGRADE:
        pidx = tle.program_id(0)
        pidy = tle.program_id(1)
    else:
        pidx = tl.program_id(0)
        pidy = tl.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Reorder CTAs into groups of GROUP_M rows to improve L2 reuse.
        if UPGRADE:
            gridx = tle.num_programs(0)
            gridy = tle.num_programs(1)
        else:
            gridx = tl.num_programs(0)
            gridy = tl.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group may be ragged when gridx is not a multiple of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Row/column masks are only needed when the dimension is not tile-divisible.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Build load masks; None lets Triton emit unmasked (faster) loads.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # Advance along K for the next iteration.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm(A, B):
    """Batched matrix multiply: (batch, M, K) @ (batch, K, N) -> (batch, M, N).

    Both inputs are made contiguous before launch (copying if needed); the
    output is a new tensor with A's dtype and device.
    """
    logger.debug("METAX GEMS BMM")
    batch, M, K = A.shape
    _, _, N = B.shape
    logger.debug(
        # Fix: the label said "ADDMM_OUT" (copy-paste from another op).
        "METAX GEMS BMM, [shape info]: [%s, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        batch,
        M,
        N,
        K,
        # A per-matrix column-major layout means the row dimension (dim -2)
        # has stride 1.  The original probed stride(0) == 1, which on a 3-D
        # batched tensor tests the batch stride, not the matrix layout.
        A.stride(-2) == 1,
        B.stride(-2) == 1,
    )
    A = A.contiguous()
    B = B.contiguous()
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    # One CTA per TILE_M x TILE_N output tile; the third grid axis is batch.
    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K, batch)
    return out