Coverage for src/flag_gems/runtime/backend/_ascend/ops/bmm.py: 0%
62 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-scoped logger; the name encodes the backend/op path so log filtering
# can target "flag_gems.runtime._ascend.ops.<module>" specifically.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", maxsplit=1)[-1]
)
# NOTE(review): the original comment here read only "# avoid" — it appears
# truncated (possibly "avoid recompilation"); confirm intent with the author.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,  # pointer to LHS; bmm() below makes it contiguous, so row-major (batch, M, K)
    B,  # pointer to RHS; contiguous (batch, K, N)
    O,  # pointer to output; contiguous (batch, M, N)
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # CTA-swizzle group width along M; 1 disables reordering
    DIVISIBLE_M: tl.constexpr,  # NOTE(review): never used in the body — see note below
    DIVISIBLE_N: tl.constexpr,  # NOTE(review): never used in the body
    DIVISIBLE_K: tl.constexpr,  # NOTE(review): never used in the body
):
    """Compute one (TILE_M, TILE_N) tile of O[b] = A[b] @ B[b].

    Grid layout: axis 0 tiles M, axis 1 tiles N, axis 2 is the batch index.
    Accumulation is in float32 via tl.dot with allow_tf32=False.
    """
    # batch offsets: advance all three base pointers to the pid_b-th matrix.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)
    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: swizzle the 2-D launch grid into groups of GROUP_M
        # consecutive M-tiles so CTAs in a group reuse the same B columns
        # (a cache-locality remapping; the set of tiles computed is unchanged).
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # Last group may be ragged when gridx is not a multiple of GROUP_M.
        # NOTE(review): when gridx % GROUP_M == 0 the tl.where arm yields 0,
        # but that arm is only selected for an out-of-range group_id — verify
        # the launch grid never produces such a pid.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    # Element offsets of this CTA's tile within the (M, N) output matrix.
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for i in range(num_iters):
        # Mask only the K (reduction) dimension for the tail iteration.
        # NOTE(review): no mask covers M/N here — rows past M / cols past N are
        # read unmasked (the final store masks them out of the result). Also,
        # tl.load is given no `other=`, so masked-off lanes are undefined and
        # still feed tl.dot. Confirm against the tuned configs / DIVISIBLE_*
        # heuristics that this cannot corrupt the tail-K partial products.
        mask_a = offs_k[None, :] < K - i * TILE_K
        mask_b = offs_k[:, None] < K - i * TILE_K
        a = tl.load(a_ptrs, mask=mask_a)
        b = tl.load(b_ptrs, mask=mask_b)

        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    # Store only the in-bounds portion of the tile.
    mask_m = (pid_m * TILE_M + tl.arange(0, TILE_M)) < M
    mask_n = (pid_n * TILE_N + tl.arange(0, TILE_N)) < N
    mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm(A, B):
    """Batched matrix multiply: returns a (batch, M, N) tensor A @ B.

    A is (batch, M, K) and B is (batch, K, N); both are made contiguous
    before launch. The output matches A's dtype and device.
    """
    logger.debug("GEMS_ASCEND BMM")
    batch, M, K = A.shape
    N = B.shape[-1]
    A = A.contiguous()
    B = B.contiguous()
    result = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid(meta):
        # One CTA per (TILE_M, TILE_N) output tile; grid axis 2 spans the batch.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid](A, B, result, M, N, K)
    return result