Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/baddbmm.py: 0%

148 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12if runtime.device.vendor_name == "iluvatar": 

13 from flag_gems.runtime.backend._iluvatar.ops.bmm import bmm 

14else: 

15 from .bmm import bmm 

16 

17from .mul import mul 

18 

19logger = logging.getLogger(__name__) 

20 

21 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(runtime.get_heuristic_config("baddbmm"))
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,  # (batch, M, K) operand; addressed as contiguous (offset pid_b * M * K)
    B,  # (batch, K, N) operand; addressed as contiguous (offset pid_b * K * N)
    O,  # (batch, M, N) output; addressed as contiguous
    bias,  # bias tensor, addressed through the explicit strides below
    alpha,  # scale on the matmul term
    beta,  # scale on the bias term
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # number of M-tiles swizzled together per group
    DIVISIBLE_M: tl.constexpr,  # True when M % TILE_M == 0 (row masks elided)
    DIVISIBLE_N: tl.constexpr,  # True when N % TILE_N == 0 (col masks elided)
    DIVISIBLE_K: tl.constexpr,  # True when K % TILE_K == 0 (k masks elided)
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
):
    # One program computes one (TILE_M, TILE_N) tile of
    #   O[b] = alpha * (A[b] @ B[b]) + beta * bias[b]
    # for the batch index taken from grid axis 2.

    # batch offsets
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        # No swizzling: grid axes map directly to tile coordinates.
        pid_m, pid_n = pidx, pidy
    else:
        # Re-linearize the 2D grid and walk it in groups of GROUP_M M-tiles,
        # so programs launched close together touch nearby rows (cache reuse).
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M
        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The trailing group is narrower when gridx is not a GROUP_M multiple.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Bounds masks are only materialized for ragged dimensions.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Build the narrowest load mask applicable to this K-step.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # offs_k advances each iteration, so the K mask is re-derived here.
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
        a = tl.load(a_ptrs, mask=mask_a)
        b = tl.load(b_ptrs, mask=mask_b)
        # Accumulate in fp32; TF32 is disabled to keep full fp32 precision.
        accumulator += tl.dot(a, b, allow_tf32=False)
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride

    # Combined output/bias mask: None on the fully-divisible fast path,
    # otherwise the AND of whichever bounds checks are needed.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N

    bi = tl.load(bias_ptrs, mask=mask_c)
    # out = alpha * (A @ B) + beta * bias, cast back to the bias dtype.
    out = accumulator * alpha + bi * beta
    o = out.to(bi.dtype)
    tl.store(o_ptrs, o, mask=mask_c)

135 

136 

class BaddbmmFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton baddbmm kernel.

    Forward computes ``out = beta * bias + alpha * (A @ B)`` batch-wise;
    backward derives the gradients for ``bias``, ``A`` and ``B`` analytically
    via the module-level helper functions.
    """

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        logger.debug("GEMS BADDBMM FORWARD")

        # NOTE(review): tensors are saved before the .contiguous() calls below.
        # The public `baddbmm` wrapper already passes contiguous tensors, so in
        # that path the saved and kernel-side tensors share storage — confirm
        # if this Function is ever invoked directly with non-contiguous inputs.
        ctx.save_for_backward(A, B, bias)
        ctx.alpha = alpha
        ctx.beta = beta

        batch, M, K = A.shape
        _, _, N = B.shape
        A = A.contiguous()
        B = B.contiguous()
        out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

        # Materialize bias at the full (batch, M, N) shape so the kernel can
        # address it with plain strides regardless of how it was broadcast.
        bbias = torch.broadcast_to(bias, (batch, M, N)).contiguous()
        bias_batch_stride = bbias.stride(0)
        bias_M_stride = bbias.stride(1)
        bias_N_stride = bbias.stride(-1)

        # One program per (TILE_M, TILE_N) output tile, one grid layer per
        # batch; TILE_M/TILE_N come from the autotuner configs at launch time.
        grid = lambda meta: (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )
        with torch_device_fn.device(A.device):
            baddbmm_kernel[grid](
                A,
                B,
                out,
                bbias,
                alpha,
                beta,
                M,
                N,
                K,
                bias_batch_stride=bias_batch_stride,
                bias_M_stride=bias_M_stride,
                bias_N_stride=bias_N_stride,
            )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return (grad_bias, grad_A, grad_B, None, None), matching forward's
        (bias, A, B, beta, alpha) argument order."""
        logger.debug("GEMS BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors

        grad_A = None
        grad_B = None
        grad_bias = None
        if ctx.needs_input_grad[0]:
            # d/d(bias) = beta * grad_output, reduced over broadcast dims.
            grad_bias = compute_bias_grad(grad_output, ctx.beta, bias)
        if ctx.needs_input_grad[1]:
            # d/dA = alpha * grad_output @ B^T
            grad_A = compute_A_grad(grad_output, B, ctx.alpha)
        if ctx.needs_input_grad[2]:
            # d/dB = alpha * A^T @ grad_output
            grad_B = compute_B_grad(A, grad_output, ctx.alpha)

        return grad_bias, grad_A, grad_B, None, None

195 

196 

def compute_bias_grad(d_output, beta, bias):
    """Gradient of baddbmm w.r.t. bias: beta * d_output, reduced over any
    dimensions along which bias was broadcast, reshaped back to bias.shape."""
    g = mul(d_output, beta)
    if g.shape != bias.shape:
        # Collapse extra leading dimensions that bias does not have.
        while g.dim() > bias.dim():
            g = g.sum(dim=0)
        # Reduce every axis that was broadcast up from size 1.
        for axis, size in enumerate(bias.shape):
            if size == 1 and g.shape[axis] > 1:
                g = g.sum(dim=axis, keepdim=True)
    return g.view(bias.shape)

207 

208 

def compute_A_grad(d_output, B, alpha):
    """Gradient of baddbmm w.r.t. A: alpha * (d_output @ B^T).

    For fp16 inputs the product is accumulated in fp32 and cast back to fp16.
    """
    B_t = B.transpose(1, 2)
    if B.dtype == torch.float16:
        # Upcast both operands so the bmm accumulates in fp32, then cast back.
        product = bmm(d_output.to(torch.float32), B_t.to(torch.float32))
        return mul(product, alpha).to(torch.float16)
    return mul(bmm(d_output, B_t), alpha)

221 

222 

def compute_B_grad(A, d_output, alpha):
    """Gradient of baddbmm w.r.t. B: alpha * (A^T @ d_output).

    For fp16 inputs the product is accumulated in fp32 and cast back to fp16.
    """
    A_t = A.transpose(1, 2)
    if A.dtype == torch.float16:
        # Upcast both operands so the bmm accumulates in fp32, then cast back.
        product = bmm(A_t.to(torch.float32), d_output.to(torch.float32))
        return mul(product, alpha).to(torch.float16)
    return mul(bmm(A_t, d_output), alpha)

235 

236 

def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Batched ``beta * bias + alpha * (A @ B)`` via the autograd Function.

    All tensor arguments are made contiguous before dispatch so the kernel's
    flat-stride addressing holds.
    """
    tensors = (bias.contiguous(), A.contiguous(), B.contiguous())
    return BaddbmmFunction.apply(*tensors, beta, alpha)