Coverage for src/flag_gems/ops/baddbmm.py: 29%
146 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from .. import runtime
8from ..runtime import torch_device_fn
9from ..utils import libentry, libtuner
10from ..utils import triton_lang_extension as tle
11from .bmm import bmm
12from .mul import mul
14logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(runtime.get_heuristic_config("baddbmm"))
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,
    B,
    O,
    bias,
    alpha,
    beta,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
):
    # Computes O[b] = alpha * (A[b] @ B[b]) + beta * bias[b] for one
    # (TILE_M x TILE_N) output tile of one batch.  A, B and O are addressed
    # as contiguous (batch, M, K) / (batch, K, N) / (batch, M, N) buffers
    # (offsets below are computed from M/N/K directly); bias is addressed
    # through the explicit strides so broadcast layouts are supported.
    # The DIVISIBLE_* flags (set by the heuristic config) let the compiler
    # drop bounds masks when a dimension is an exact multiple of its tile.

    # batch offsets: advance each base pointer to this batch's matrix.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Remap the 2D launch grid so that GROUP_M consecutive tiles along M
        # are visited before moving along N (improves cache reuse of B tiles).
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group along M may be ragged when gridx % GROUP_M != 0.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Row/column bounds masks are only needed for non-divisible dimensions.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # Accumulate in fp32 regardless of the input dtype.
    accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Build load masks only for the dimensions that actually need them;
        # a None mask lets Triton emit unmasked (faster) loads.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
        a = tl.load(a_ptrs, mask=mask_a)
        b = tl.load(b_ptrs, mask=mask_b)
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance one TILE_K step along the K dimension.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride

    # Combined store/bias-load mask for the output tile.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N

    bi = tl.load(bias_ptrs, mask=mask_c)
    out = accumulator * alpha + bi * beta
    # Cast back to the bias dtype before storing.
    o = out.to(bi.dtype)
    tl.store(o_ptrs, o, mask=mask_c)
class BaddbmmFunction(torch.autograd.Function):
    """Autograd-aware baddbmm: out = beta * bias + alpha * bmm(A, B).

    Forward launches the Triton ``baddbmm_kernel``; backward routes to the
    module-level ``compute_*_grad`` helpers built on the gems bmm/mul ops.
    """

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        """Compute beta * bias + alpha * (A @ B) for batched 3D A, B.

        A is (batch, M, K), B is (batch, K, N); bias must broadcast to
        (batch, M, N).  Returns a new (batch, M, N) tensor of A's dtype.
        """
        logger.debug("GEMS BADDBMM FORWARD")

        # NOTE(review): the pre-contiguous A/B are saved here while the
        # kernel below consumes contiguous copies; backward's bmm/mul must
        # therefore accept arbitrary strides — confirm against those ops.
        ctx.save_for_backward(A, B, bias)
        ctx.alpha = alpha
        ctx.beta = beta

        batch, M, K = A.shape
        _, _, N = B.shape
        A = A.contiguous()
        B = B.contiguous()
        out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

        # Materialize bias at full (batch, M, N) so the kernel can address it
        # with simple per-dimension strides.
        bbias = torch.broadcast_to(bias, (batch, M, N)).contiguous()
        bias_batch_stride = bbias.stride(0)
        bias_M_stride = bbias.stride(1)
        bias_N_stride = bbias.stride(-1)

        # 3D launch grid: (tiles over M, tiles over N, batch).
        grid = lambda meta: (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )
        with torch_device_fn.device(A.device):
            baddbmm_kernel[grid](
                A,
                B,
                out,
                bbias,
                alpha,
                beta,
                M,
                N,
                K,
                bias_batch_stride=bias_batch_stride,
                bias_M_stride=bias_M_stride,
                bias_N_stride=bias_N_stride,
            )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return grads matching forward's inputs (bias, A, B, beta, alpha)."""
        logger.debug("GEMS BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors

        grad_A = None
        grad_B = None
        grad_bias = None
        # needs_input_grad indices follow forward's argument order:
        # 0 = bias, 1 = A, 2 = B.
        if ctx.needs_input_grad[0]:
            grad_bias = compute_bias_grad(grad_output, ctx.beta, bias)
        if ctx.needs_input_grad[1]:
            grad_A = compute_A_grad(grad_output, B, ctx.alpha)
        if ctx.needs_input_grad[2]:
            grad_B = compute_B_grad(A, grad_output, ctx.alpha)

        # beta and alpha are non-tensor scalars: no gradients.
        return grad_bias, grad_A, grad_B, None, None
def compute_bias_grad(d_output, beta, bias):
    """Reduce ``beta * d_output`` back down to bias's (broadcast) shape."""
    g = mul(d_output, beta)
    if g.shape != bias.shape:
        # Collapse the leading dimensions that bias never had.
        for _ in range(g.dim() - bias.dim()):
            g = g.sum(dim=0)
        # Re-collapse any axes that were broadcast from size 1.
        for axis, size in enumerate(bias.shape):
            if size == 1 and g.shape[axis] > 1:
                g = g.sum(dim=axis, keepdim=True)
    return g.view(bias.shape)
def compute_A_grad(d_output, B, alpha):
    """Return grad_A = alpha * (d_output @ B^T); fp16 runs the bmm in fp32."""
    B_t = B.transpose(1, 2)
    if B.dtype != torch.float16:
        return mul(bmm(d_output, B_t), alpha)
    # fp16 path: upcast both operands, multiply, then cast back.
    prod = bmm(d_output.to(torch.float32), B_t.to(torch.float32))
    return mul(prod, alpha).to(torch.float16)
def compute_B_grad(A, d_output, alpha):
    """Return grad_B = alpha * (A^T @ d_output); fp16 runs the bmm in fp32."""
    A_t = A.transpose(1, 2)
    if A.dtype != torch.float16:
        return mul(bmm(A_t, d_output), alpha)
    # fp16 path: upcast both operands, multiply, then cast back.
    prod = bmm(A_t.to(torch.float32), d_output.to(torch.float32))
    return mul(prod, alpha).to(torch.float16)
def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Entry point: beta * bias + alpha * bmm(A, B), with autograd support."""
    # Kernel assumes contiguous layouts; normalize the tensor inputs first.
    inputs = (bias.contiguous(), A.contiguous(), B.contiguous(), beta, alpha)
    return BaddbmmFunction.apply(*inputs)