Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/bmm.py: 0%
102 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
import logging

import torch
import triton
import triton.language as tl

# from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

# Child logger under the shared "flag_gems" root; lstrip(".") removes any
# leading dots from __name__ before using it as the child logger name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_group_m(args):
    """Choose the CTA group size along M for the bmm kernel's swizzle.

    Tall-and-narrow tiles (TILE_M > TILE_N) disable grouping by returning 1;
    otherwise every M-tile goes into one group, i.e. ceil(M / TILE_M).
    """
    tile_m = args["TILE_M"]
    if tile_m > args["TILE_N"]:
        return 1
    # ceiling division via negated floor division
    return -(-args["M"] // tile_m)
def heur_divisible_m(args):
    """Return True when M is an exact multiple of TILE_M (no M mask needed)."""
    return not args["M"] % args["TILE_M"]
def heur_divisible_n(args):
    """Return True when N is an exact multiple of TILE_N (no N mask needed)."""
    return not args["N"] % args["TILE_N"]
def heur_divisible_k(args):
    """Return True when K is an exact multiple of TILE_K (no K mask needed)."""
    return not args["K"] % args["TILE_K"]
# Batched matmul kernel: each program computes one (TILE_M, TILE_N) tile of
# O[b] = A[b] @ B[b], with A (batch, M, K), B (batch, K, N), O (batch, M, N),
# all addressed with raw row-major offsets (so contiguous layout is assumed).
# Grid: (cdiv(M, TILE_M), cdiv(N, TILE_N), batch).  Accumulation is float32
# with TF32 disabled; the store casts back to O's element type.
@libentry()
@triton.autotune(
    configs=[],
    # NOTE(review): `generate_configs` is not a standard triton.autotune
    # argument — presumably a kunlunxin-backend autotuner extension that
    # supplies the "bmm" config space at runtime; confirm against the
    # patched autotuner.
    generate_configs="bmm",
    key=["M", "N", "K"],
)
@triton.heuristics(
    {
        # GROUP_M controls CTA swizzling along M; the DIVISIBLE_* flags let
        # the kernel skip boundary masks when a dim divides its tile evenly.
        "GROUP_M": heur_group_m,
        "DIVISIBLE_M": heur_divisible_m,
        "DIVISIBLE_N": heur_divisible_n,
        "DIVISIBLE_K": heur_divisible_k,
    }
)
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    # batch offsets: advance each base pointer to this batch's matrix
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: walk GROUP_M consecutive M-tiles before advancing
        # along N (swizzle for better cache reuse of the B tiles)
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # the last group may be ragged when gridx is not a multiple of GROUP_M
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    # element offsets of this program's output tile
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # boundary masks are only materialized when the dimension is ragged
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    # row-major addressing within this batch's matrices
    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # float32 accumulator regardless of input dtype
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            # no K mask needed; only ragged M/N edges require masking
            if DIVISIBLE_M:
                mask_a = tl.full([TILE_M, TILE_K], value=1, dtype=tl.int1)
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = tl.full([TILE_K, TILE_N], value=1, dtype=tl.int1)
            else:
                mask_b = mask_n[None, :]
        else:
            # mask the current K slice, then advance offs_k for the next pass
            mask_k = offs_k < K
            offs_k += TILE_K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # advance one K tile: A moves along its columns, B along its rows
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)  # fp32 accumulate, TF32 off

    # output mask mirrors the M/N raggedness combination
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = tl.full([TILE_M, TILE_N], value=1, dtype=tl.int1)
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm(A, B):
    """Batched matrix multiply: return a new tensor ``A @ B``.

    Args:
        A: tensor of shape (batch, M, K).
        B: tensor of shape (batch, K, N).

    Returns:
        A contiguous tensor of shape (batch, M, N) with A's dtype and device.
    """
    logger.debug("GEMS BMM")
    # Validate shapes up front, mirroring bmm_out (previously a mismatched B
    # would fail obscurely inside the kernel or produce garbage).
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    # The kernel addresses operands with raw row-major offsets, so both
    # inputs must be contiguous.
    A = A.contiguous()
    B = B.contiguous()
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid_fn(meta):
        # 2D tile grid over (M, N); batch rides on the third grid axis.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out
def bmm_out(A, B, out):
    """Batched matrix multiply into a preallocated output: ``out = A @ B``.

    Args:
        A: tensor of shape (batch, M, K).
        B: tensor of shape (batch, K, N).
        out: destination tensor of shape (batch, M, N).

    Returns:
        ``out``, for call-chaining convenience.
    """
    logger.debug("GEMS BMM_OUT")
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    # Fix: the kernel addresses operands with raw row-major offsets and bmm()
    # enforces contiguity, but bmm_out previously did not — non-contiguous
    # inputs silently produced wrong results.
    A = A.contiguous()
    B = B.contiguous()
    # NOTE(review): `out` is also written through raw offsets, so it is
    # assumed contiguous here — TODO confirm callers guarantee this.

    def grid_fn(meta):
        # 2D tile grid over (M, N); batch rides on the third grid axis.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out