Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/baddbmm.py: 0%
148 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
10from flag_gems.utils import triton_lang_extension as tle
12if runtime.device.vendor_name == "iluvatar":
13 from flag_gems.runtime.backend._iluvatar.ops.bmm import bmm
14else:
15 from .bmm import bmm
17from .mul import mul
19logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(runtime.get_heuristic_config("baddbmm"))
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,  # (batch, M, K) input; assumed contiguous, batch stride M*K
    B,  # (batch, K, N) input; assumed contiguous, batch stride K*N
    O,  # (batch, M, N) output; assumed contiguous, batch stride M*N
    bias,  # bias already broadcast to (batch, M, N) by the caller
    alpha,  # scalar scale for A @ B
    beta,  # scalar scale for bias
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # M-tiles per swizzle group; 1 disables swizzling
    DIVISIBLE_M: tl.constexpr,  # True when masking along M can be skipped
    DIVISIBLE_N: tl.constexpr,  # True when masking along N can be skipped
    DIVISIBLE_K: tl.constexpr,  # True when masking along K can be skipped
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
):
    # Computes one (TILE_M, TILE_N) tile of O[b] = alpha * (A[b] @ B[b]) + beta * bias[b].
    # batch offsets: grid axis 2 selects the batch element; advance base pointers.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride
    pidx = tle.program_id(0)
    pidy = tle.program_id(1)
    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Grouped (swizzled) tile ordering: consecutive programs cover GROUP_M
        # tiles along M before stepping along N, improving cache reuse.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M
        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The final group is ragged when gridx is not a multiple of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)
    # Edge masks are only materialized when the dimension is ragged.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N
    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]
    # Accumulate A @ B over K in TILE_K chunks, in fp32 for accuracy.
    num_iters = tl.cdiv(K, TILE_K)
    accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            # No K mask needed; only ragged M/N edges require masking.
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # offs_k advances each iteration, so this also masks the K tail.
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
        a = tl.load(a_ptrs, mask=mask_a)
        b = tl.load(b_ptrs, mask=mask_b)
        accumulator += tl.dot(a, b, allow_tf32=False)
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N
    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride
    # Build the combined bias-load / store mask only when an edge is ragged.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N
    bi = tl.load(bias_ptrs, mask=mask_c)
    # Epilogue: out = alpha * (A @ B) + beta * bias, cast back to the bias dtype.
    out = accumulator * alpha + bi * beta
    o = out.to(bi.dtype)
    tl.store(o_ptrs, o, mask=mask_c)
class BaddbmmFunction(torch.autograd.Function):
    """Autograd wrapper for the fused baddbmm Triton kernel.

    Forward computes ``out = beta * bias + alpha * (A @ B)`` batched;
    backward derives gradients with the bmm-based helpers defined below.
    """

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        logger.debug("GEMS BADDBMM FORWARD")

        # Save inputs for backward; alpha/beta are plain Python scalars,
        # so they are stashed on ctx directly.
        ctx.save_for_backward(A, B, bias)
        ctx.alpha = alpha
        ctx.beta = beta

        # Assumes A is (batch, M, K) and B is (batch, K, N) — enforced by callers.
        batch, M, K = A.shape
        _, _, N = B.shape
        A = A.contiguous()
        B = B.contiguous()
        out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

        # Materialize the bias broadcast so the kernel can address it
        # with simple per-dimension strides.
        bbias = torch.broadcast_to(bias, (batch, M, N)).contiguous()
        bias_batch_stride = bbias.stride(0)
        bias_M_stride = bbias.stride(1)
        bias_N_stride = bbias.stride(-1)

        # One program per (M-tile, N-tile, batch) triple.
        grid = lambda meta: (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )
        with torch_device_fn.device(A.device):
            baddbmm_kernel[grid](
                A,
                B,
                out,
                bbias,
                alpha,
                beta,
                M,
                N,
                K,
                bias_batch_stride=bias_batch_stride,
                bias_M_stride=bias_M_stride,
                bias_N_stride=bias_N_stride,
            )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        logger.debug("GEMS BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors

        grad_A = None
        grad_B = None
        grad_bias = None
        # needs_input_grad order mirrors forward's (bias, A, B, beta, alpha).
        if ctx.needs_input_grad[0]:
            grad_bias = compute_bias_grad(grad_output, ctx.beta, bias)
        if ctx.needs_input_grad[1]:
            grad_A = compute_A_grad(grad_output, B, ctx.alpha)
        if ctx.needs_input_grad[2]:
            grad_B = compute_B_grad(A, grad_output, ctx.alpha)

        # beta and alpha are non-tensor inputs: no gradient for them.
        return grad_bias, grad_A, grad_B, None, None
def compute_bias_grad(d_output, beta, bias):
    """Gradient of baddbmm w.r.t. bias: beta * d_output, reduced back to bias's shape."""
    grad_bias = mul(d_output, beta)
    if grad_bias.shape != bias.shape:
        # Collapse the leading dimensions that broadcasting prepended.
        extra_dims = grad_bias.dim() - bias.dim()
        for _ in range(extra_dims):
            grad_bias = grad_bias.sum(dim=0)
        # Reduce every size-1 dimension that broadcasting expanded.
        for axis, size in enumerate(bias.shape):
            if size == 1 and grad_bias.shape[axis] > 1:
                grad_bias = grad_bias.sum(dim=axis, keepdim=True)
    return grad_bias.view(bias.shape)
def compute_A_grad(d_output, B, alpha):
    """Gradient of baddbmm w.r.t. A: alpha * (d_output @ B^T)."""
    B_T = B.transpose(1, 2)
    if B.dtype != torch.float16:
        return mul(bmm(d_output, B_T), alpha)
    # fp16 path: upcast to fp32 for the matmul, then cast the result back.
    prod = bmm(d_output.to(torch.float32), B_T.to(torch.float32))
    return mul(prod, alpha).to(torch.float16)
def compute_B_grad(A, d_output, alpha):
    """Gradient of baddbmm w.r.t. B: alpha * (A^T @ d_output)."""
    A_T = A.transpose(1, 2)
    if A.dtype != torch.float16:
        return mul(bmm(A_T, d_output), alpha)
    # fp16 path: upcast to fp32 for the matmul, then cast the result back.
    prod = bmm(A_T.to(torch.float32), d_output.to(torch.float32))
    return mul(prod, alpha).to(torch.float16)
def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Batched ``beta * bias + alpha * (A @ B)`` with autograd support.

    Inputs are made contiguous before dispatching to BaddbmmFunction.
    """
    args = (
        bias.contiguous(),
        A.contiguous(),
        B.contiguous(),
        beta,
        alpha,
    )
    return BaddbmmFunction.apply(*args)