Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/softmax.py: 0%

180 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Child logger nested under the shared "flag_gems" logger.  lstrip(".") strips
# any leading dots from __name__ (a no-op for normal absolute module names).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

@triton.jit
def next_multiple_of(a, b):
    # Round `a` up to the smallest multiple of `b` that is >= `a`
    # (i.e. the smallest x >= a with x % b == 0).
    rounded_up = tl.cdiv(a, b)
    return rounded_up * b

19 

20 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of `b` strictly below `a`
    # (i.e. the largest x < a with x % b == 0).
    next_multiple = tl.cdiv(a, b) * b
    return next_multiple - b

25 

26 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,  # number of rows; one program instance handles one row
    N,  # row length (softmax/reduction dimension)
    TILE_N: tl.constexpr,  # columns processed per tile
    ONE_TILE_PER_CTA: tl.constexpr,  # True when TILE_N covers the whole row
):
    """Row-wise softmax over a contiguous (M, N) buffer.

    Each program computes softmax for row `pid_m`.  When the row fits in a
    single tile, a straightforward max/exp/sum is used; otherwise an online
    (streaming) softmax maintains a running max `m` and a rescaled running
    sum `z` across tiles, then a second pass writes the normalized outputs.
    """
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # Whole row in one tile: load, reduce, normalize, store.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # NOTE(review): the input is cast to the OUTPUT dtype before the
        # reduction, so for fp16/bf16 outputs the max/exp/sum run at reduced
        # precision — presumably intentional for this backend; confirm.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)  # row max, for numerical stability
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)  # normalizer
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax: per-lane running max `m` and running sum `z`,
        # accumulated in fp32 across tiles of the row.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Pass 1a: full tiles (no mask needed below previous_multiple).
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input; skip the
            # update then, since exp(m - m_new) would be exp(-inf - -inf) = NaN
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Pass 1b: the (possibly partial) last tile, with masking.
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the per-lane partial results into row-level scalars.
        m_reduced = tl.max(m, 0)  # overall row max
        z = tl.sum(z * tl.exp(m - m_reduced), 0)  # rescaled overall sum
        m = m_reduced

        # Pass 2: write outputs, walking the tiles in REVERSE order so the
        # most recently loaded (cached) tail tile is revisited first;
        # evict_first since each element is read exactly once here.
        previous_multiple = prev_multiple_of(N, TILE_N)
        # specialize the first iteration: the partial tail tile needs a mask
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining full tiles, still walking backwards; no mask required.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)

100 

101 

102# ------------------------ backward ------------------------------- 

103 

104 

def softmax_backward_kernel_inner_heur_tile_m(args):
    """Rows handled per CTA: split M evenly over the 12 clusters (ceil)."""
    num_clusters = 12
    return triton.cdiv(args["M"], num_clusters)

108 

109 

def softmax_backward_kernel_inner_heru_tile_n(args):
    """Columns per tile: the full row length N, capped at 4096."""
    # builtins.min guards against `min` being shadowed in this module's scope.
    import builtins

    n_cols = args["N"]
    return builtins.min(n_cols, 4096)

115 

116 

def softmax_backward_kernel_inner_heur_one_tile_per_cta(args):
    """True when a single tile spans the whole reduction dimension."""
    return not (args["TILE_N"] < args["N"])

119 

120 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("softmax_inner"),
#     key=["M", "N"],
# )
# @triton.heuristics(
#     values=runtime.get_heuristic_config("softmax_backward_inner"),
# )
@triton.heuristics(
    values={
        "TILE_M": softmax_backward_kernel_inner_heur_tile_m,
        "TILE_N": softmax_backward_kernel_inner_heru_tile_n,
        "ONE_TILE_PER_CTA": softmax_backward_kernel_inner_heur_one_tile_per_cta,
    },
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,  # softmax forward output y, shape (M, N)
    out_grad_ptr,  # upstream gradient dy, shape (M, N)
    in_grad_ptr,  # result dx, shape (M, N)
    M,  # number of rows
    N,  # row length (softmax dimension)
    TILE_M: tl.constexpr,  # rows per CTA
    TILE_N: tl.constexpr,  # columns per tile
    ONE_TILE_PER_CTA: tl.constexpr,  # True when TILE_N covers the whole row
):
    """Softmax VJP: dx = y * (dy - sum_j(y_j * dy_j)) per row.

    Each CTA processes TILE_M consecutive rows; accumulation is done in
    fp64.  When the row does not fit in one tile, a first pass accumulates
    the per-row dot product `scale = sum(y * dy)` and a second pass writes
    the gradient tile by tile.
    """
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Whole rows in one tile: single load/reduce/store round trip.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float64)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
        scale = tl.sum(out_tile * out_grad_tile, 1)  # per-row sum(y * dy)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate per-lane partial products of y * dy in fp64.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float64)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_last: y is re-read in pass 2, keep it cached.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            ).to(tl.float64)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: dx = y * (dy - scale), written tile by tile.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # NOTE(review): unlike pass 1, out_tile is NOT cast to fp64 here;
            # the product promotes anyway, but the asymmetry looks
            # unintentional — confirm against the reference implementation.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            )
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None]).to(tl.float64)
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N

186 

187 

def softmax(self, dim, half_to_float=False):
    """Compute softmax of `self` along dimension `dim` via the Triton kernel.

    Args:
        self: input tensor (any shape; made contiguous internally).
        dim: dimension to reduce over; may be negative.
        half_to_float: if True, compute and return in float32 regardless of
            the input dtype.

    Returns:
        Tensor of the same shape as `self` containing softmax values.

    The inner kernel reduces over the last axis of a 2-D view.  With
    pre-dims M, softmax length N and post-dims K, the input is permuted to
    [M, K, N] and flattened to [M*K, N]; the result is permuted back.

    Fixes over the previous version: the layout-restore logic special-cased
    ndim 2/3 (NameError on `m, n, k` for higher ranks) and called `.view()`
    on a non-contiguous transpose in the `M == 1`, 3-D branch (RuntimeError).
    A single `view(M, K, N).transpose(1, 2).reshape(shape)` handles every
    rank and is value-identical for the cases the old code supported.
    """
    logger.debug("GEMS SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]  # product of the pre-dims
    self = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    K = self.numel() // M // N  # product of the post-dims

    with torch_device_fn.device(self.device):
        if K > 1:
            # Move the softmax axis last: [M, N, K] -> [M, K, N] -> [M*K, N],
            # so the kernel reduces over the innermost dimension.
            inp_reshaped = (
                self.view(M, N, K).transpose(1, 2).contiguous().view(M * K, N)
            )
            out_reshaped = torch.empty_like(inp_reshaped, dtype=dtype)

            grid = (M * K, 1, 1)  # one program per row
            softmax_kernel_inner[grid](
                out_reshaped,
                inp_reshaped,
                M * K,
                N,
                buffer_size_limit=2048,
                is_use_mask_zero=True,
            )

            # Restore the original layout for any rank:
            # [M*K, N] -> [M, K, N] -> [M, N, K] -> original shape.
            # reshape (not view) because the transpose is non-contiguous.
            out = out_reshaped.view(M, K, N).transpose(1, 2).reshape(self.shape)
        else:
            # Softmax axis is already innermost; run the kernel in place.
            out = torch.empty_like(self, dtype=dtype)
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](
                out,
                self,
                M,
                N,
                buffer_size_limit=2048,
                isCloseVectorization=True,
                is_use_mask_zero=True,
            )
    return out

254 

255 

def softmax_backward(grad_output, output, dim, input_dtype):
    """Softmax VJP: dx = y * (dy - sum(y * dy)) along `dim`.

    Args:
        grad_output: upstream gradient dy, same shape as `output`.
        output: forward softmax result y.
        dim: softmax dimension; may be negative.
        input_dtype: dtype to cast the returned gradient to.

    Returns:
        Input gradient tensor with the shape of `output` and dtype
        `input_dtype` (accumulation is done in float64 internally).

    Fixes over the previous version: the layout-restore logic special-cased
    ndim 2/3 (NameError on `m, n, k` for higher ranks) and called `.view()`
    on a non-contiguous transpose in the `M == 1`, 3-D branch (RuntimeError).
    A single `view(M, K, N).transpose(1, 2).reshape(shape)` handles every
    rank and is value-identical for the cases the old code supported.
    """
    logger.debug("GEMS SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]  # product of the pre-dims

    grad_output = grad_output.contiguous()
    output = output.contiguous()
    K = output.numel() // M // N  # product of the post-dims

    with torch_device_fn.device(output.device):
        if K > 1:
            # Move the softmax axis last: [M, N, K] -> [M, K, N] -> [M*K, N],
            # so the kernel reduces over the innermost dimension.
            out_grad_reshaped = (
                grad_output.view(M, N, K).transpose(1, 2).contiguous().view(M * K, N)
            )
            out_reshaped = (
                output.view(M, N, K).transpose(1, 2).contiguous().view(M * K, N)
            )
            in_grad_reshaped = torch.empty_like(out_reshaped, dtype=torch.float64)

            # One CTA per cluster; TILE_M splits the M*K rows across them.
            grid = (12, 1, 1)
            softmax_backward_kernel_inner[grid](
                out_reshaped,
                out_grad_reshaped,
                in_grad_reshaped,
                M * K,
                N,
                buffer_size_limit=2048,
                isCloseUnrollControl=True,
            )

            # Restore the original layout for any rank:
            # [M*K, N] -> [M, K, N] -> [M, N, K] -> original shape.
            # reshape (not view) because the transpose is non-contiguous.
            in_grad = (
                in_grad_reshaped.view(M, K, N).transpose(1, 2).reshape(output.shape)
            )
        else:
            # Softmax axis is already innermost; run the kernel in place.
            in_grad = torch.empty_like(output, dtype=torch.float64)
            grid = (12, 1, 1)
            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                buffer_size_limit=2048,
                isCloseUnrollControl=True,
            )
    return in_grad.to(input_dtype)