Coverage for src/flag_gems/ops/baddbmm.py: 29%

146 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from .. import runtime 

8from ..runtime import torch_device_fn 

9from ..utils import libentry, libtuner 

10from ..utils import triton_lang_extension as tle 

11from .bmm import bmm 

12from .mul import mul 

13 

14logger = logging.getLogger(__name__) 

15 

16 

# Tuned Triton kernel computing, for each batch b:
#   O[b] = alpha * (A[b] @ B[b]) + beta * bias[b]
# A: (batch, M, K), B: (batch, K, N), O: (batch, M, N), row-major and
# densely packed (the host side makes A/B/bias contiguous before launch).
# bias is addressed through its own constexpr strides.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(runtime.get_heuristic_config("baddbmm"))
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,
    B,
    O,
    bias,
    alpha,
    beta,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
):
    # batch offsets: program_id(2) selects the batch; advance each base
    # pointer to this batch's (M, K) / (K, N) / (M, N) slice.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        # No swizzling: grid axes map directly to tile coordinates.
        pid_m, pid_n = pidx, pidy
    else:
        # Remap the flat program id so GROUP_M consecutive CTAs share the
        # same band of M-tiles (cache-locality swizzle, presumably).
        # GROUP_SIZE shrinks to gridx % GROUP_M for the final partial group.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M
        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    # Element offsets of this tile inside the (M, N) output.
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Bounds masks are only materialized when the dimension is not an
    # exact multiple of the tile size (the DIVISIBLE_* heuristics).
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    # Row-major addressing: A[m, k] = A + m*K + k, B[k, n] = B + k*N + n.
    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # Accumulate the matmul in fp32 regardless of the input dtype.
    accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Pick the load masks for this K-slice; None means an unmasked load.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # K has a ragged tail: mask the k dimension each iteration
            # (offs_k advances below).
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
        a = tl.load(a_ptrs, mask=mask_a)
        b = tl.load(b_ptrs, mask=mask_b)
        # allow_tf32=False keeps full fp32 precision in the dot product.
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance one TILE_K slice: +TILE_K columns of A, +TILE_K rows of B.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride

    # Store/bias mask: None when the tile is fully in bounds, otherwise the
    # AND of whichever per-dimension bounds checks are needed.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N

    bi = tl.load(bias_ptrs, mask=mask_c)
    out = accumulator * alpha + bi * beta
    # Cast the fp32 result back to the bias/output dtype before storing.
    o = out.to(bi.dtype)
    tl.store(o_ptrs, o, mask=mask_c)

130 

131 

class BaddbmmFunction(torch.autograd.Function):
    """Autograd wrapper for batched ``beta * bias + alpha * (A @ B)``.

    Forward launches the fused Triton kernel; backward rebuilds the
    gradients from the saved inputs using the project's ``bmm``/``mul`` ops.
    """

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        """Compute ``beta * bias + alpha * bmm(A, B)``.

        A is (batch, M, K), B is (batch, K, N); bias must broadcast to
        (batch, M, N). Returns a new (batch, M, N) tensor of A's dtype.
        """
        logger.debug("GEMS BADDBMM FORWARD")

        # Save the original (pre-contiguous) tensors; backward only needs
        # their values and shapes.
        ctx.save_for_backward(A, B, bias)
        ctx.alpha = alpha
        ctx.beta = beta

        batch, M, K = A.shape
        _, _, N = B.shape
        # The kernel assumes densely packed row-major operands.
        A = A.contiguous()
        B = B.contiguous()
        out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

        # Materialize bias at the full output shape so the kernel can
        # address it with plain strides.
        bbias = torch.broadcast_to(bias, (batch, M, N)).contiguous()
        bias_batch_stride = bbias.stride(0)
        bias_M_stride = bbias.stride(1)
        bias_N_stride = bbias.stride(-1)

        # 2D tile grid over (M, N); the third grid axis enumerates batches.
        grid = lambda meta: (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )
        with torch_device_fn.device(A.device):
            baddbmm_kernel[grid](
                A,
                B,
                out,
                bbias,
                alpha,
                beta,
                M,
                N,
                K,
                bias_batch_stride=bias_batch_stride,
                bias_M_stride=bias_M_stride,
                bias_N_stride=bias_N_stride,
            )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients in forward's argument order:
        (bias, A, B, beta, alpha); beta/alpha are scalars -> None.
        """
        logger.debug("GEMS BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors

        grad_A = None
        grad_B = None
        grad_bias = None
        # needs_input_grad indices follow forward's (bias, A, B, ...) order.
        if ctx.needs_input_grad[0]:
            grad_bias = compute_bias_grad(grad_output, ctx.beta, bias)
        if ctx.needs_input_grad[1]:
            grad_A = compute_A_grad(grad_output, B, ctx.alpha)
        if ctx.needs_input_grad[2]:
            grad_B = compute_B_grad(A, grad_output, ctx.alpha)

        return grad_bias, grad_A, grad_B, None, None

190 

191 

def compute_bias_grad(d_output, beta, bias):
    """Gradient of baddbmm w.r.t. ``bias``: reduce ``beta * d_output``
    over every dimension that was broadcast in the forward pass.

    Args:
        d_output: upstream gradient, shaped like the forward output
            (batch, M, N).
        beta: scalar weight applied to bias in the forward pass.
        bias: the original (possibly broadcastable) bias tensor.

    Returns:
        A tensor with exactly ``bias.shape``.
    """
    grad_bias = mul(d_output, beta)
    if grad_bias.shape != bias.shape:
        # Sum away leading dimensions bias never had ...
        while grad_bias.dim() > bias.dim():
            grad_bias = grad_bias.sum(dim=0)
        # ... then any size-1 dimensions that were expanded.
        for i in range(bias.dim()):
            if bias.shape[i] == 1 and grad_bias.shape[i] > 1:
                grad_bias = grad_bias.sum(dim=i, keepdim=True)
    # reshape (not view): safe even if grad_bias is non-contiguous
    # (grad_output is not guaranteed contiguous), and a no-op when the
    # shapes already match.
    return grad_bias.reshape(bias.shape)

202 

203 

def compute_A_grad(d_output, B, alpha):
    """Gradient of baddbmm w.r.t. A: ``alpha * (d_output @ B^T)``.

    For fp16 operands the product is computed in fp32 and cast back to
    fp16 at the end.
    """
    b_t = B.transpose(1, 2)
    if B.dtype != torch.float16:
        return mul(bmm(d_output, b_t), alpha)
    # fp16 path: upcast, multiply, scale, downcast.
    prod = bmm(d_output.to(torch.float32), b_t.to(torch.float32))
    return mul(prod, alpha).to(torch.float16)

216 

217 

def compute_B_grad(A, d_output, alpha):
    """Gradient of baddbmm w.r.t. B: ``alpha * (A^T @ d_output)``.

    For fp16 operands the product is computed in fp32 and cast back to
    fp16 at the end.
    """
    a_t = A.transpose(1, 2)
    if A.dtype != torch.float16:
        return mul(bmm(a_t, d_output), alpha)
    # fp16 path: upcast, multiply, scale, downcast.
    prod = bmm(a_t.to(torch.float32), d_output.to(torch.float32))
    return mul(prod, alpha).to(torch.float16)

230 

231 

def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Batched matrix-multiply-add: ``beta * bias + alpha * (A @ B)``.

    All tensor inputs are made contiguous before dispatching to the
    autograd-aware implementation.
    """
    args = (
        bias.contiguous(),
        A.contiguous(),
        B.contiguous(),
        beta,
        alpha,
    )
    return BaddbmmFunction.apply(*args)