Coverage for src/flag_gems/runtime/backend/_mthreads/ops/log_softmax.py: 0%

242 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of `b` that is strictly below `a`
    # (equals a - b when a is itself a multiple of b).
    return (tl.cdiv(a, b) - 1) * b

19 

20 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def log_softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Log-softmax reduced along the middle axis of an (M, N, K) row-major
    # view, i.e. the "non-inner" case where K (stride-1 axis) > 1.
    # Grid: axis 0 -> one program per M slice, axis 1 -> TILE_K chunk of K.
    pid_k = tle.program_id(1)
    pid_m = tle.program_id(0)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Fast path: the whole reduction axis N fits in a single TILE_N tile,
        # so one load/compute/store pass suffices.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # Padding lanes load -inf so they contribute exp(-inf) = 0 to the sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, 0)  # per-K-column max, for numerical stability
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = inp - m[None, :] - tl.log(z)[None, :]
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online (streaming) softmax: sweep N in TILE_N chunks while keeping
        # a running max `m` and a running sum of rescaled exponentials `z`.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            # Skip the rescale when a lane has only ever seen -inf padding:
            # exp(m - m_new) would be exp(-inf - -inf) = exp(nan) otherwise.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the per-lane partial maxima/sums down to one value per K.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced

        # Second pass: reload the input and emit inp - m - log(z).
        # Tiles are visited in reverse (last tile first) via
        # previous_multiple - start_n; presumably to reuse the most recently
        # loaded tiles while they are still cached — TODO confirm.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            o = inp - m[None, :] - tl.log(z)[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)

81 

82 

def log_softmax_heur_tile_m(args):
    """Heuristic for TILE_M (rows per CTA) in the inner kernel.

    Short rows allow batching several rows per CTA; long rows (N > 1024)
    always get one row per CTA.
    """
    num_rows = args["M"]
    row_len = args["N"]

    if row_len > 1024:
        # Large N: keep register pressure down with one row per CTA.
        return 1

    # (M threshold, TILE_M) pairs, tried from largest M down.
    if row_len <= 256:
        schedule = ((4096, 8), (1024, 4))
    else:  # 256 < N <= 1024
        schedule = ((4096, 4), (1024, 2))

    for m_threshold, tile_m in schedule:
        if num_rows >= m_threshold:
            return tile_m
    return 1

105 

106 

def log_softmax_heur_tile_n_inner(args):
    """Heuristic for TILE_N (columns per tile) in the inner kernel."""
    n_cols = args["N"]
    n_rows = args["M"]

    if n_cols > 32 * 1024:
        # Very long rows are always swept in fixed 4K chunks.
        return 4096

    padded = triton.next_power_of_2(n_cols)
    if n_cols <= 32 and n_rows > 1000:
        # Many tiny rows: widen the tile so each CTA does useful work.
        return 32
    if 1024 < n_cols <= 8192:
        # One row per CTA here; cap the tile so the kernel loops,
        # keeping register usage reasonable.
        return min(padded, 2048)
    return padded

123 

124 

def log_softmax_heur_one_tile_per_cta(args):
    """True when one TILE_N tile already spans the whole reduction length N."""
    return args["N"] <= args["TILE_N"]

127 

128 

def log_softmax_heur_num_warps_inner(args):
    """Scale the warp count with the 2-D tile footprint (TILE_M * TILE_N)."""
    footprint = args["TILE_M"] * args["TILE_N"]
    for upper_bound, warps in ((2048, 4), (4096, 8)):
        if footprint < upper_bound:
            return warps
    return 16

139 

140 

@libentry()
@triton.heuristics(
    {
        "TILE_M": log_softmax_heur_tile_m,
        "TILE_N": log_softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": log_softmax_heur_one_tile_per_cta,
        "num_warps": log_softmax_heur_num_warps_inner,
    }
)
@triton.jit
def log_softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Log-softmax along the innermost (contiguous) axis of an (M, N) view.
    # Each program handles TILE_M consecutive rows.
    pid_m = tle.program_id(0)
    m_offset = pid_m * TILE_M + tl.arange(0, TILE_M)

    if ONE_TILE_PER_CTA:
        # Fast path: an entire row fits in one TILE_N tile — single pass.
        n_offsets = tl.arange(0, TILE_N)
        offset = m_offset[:, None] * N + n_offsets[None, :]
        mask = (m_offset[:, None] < M) & (n_offsets[None, :] < N)
        input_ptrs = input_ptr + offset
        # Padding lanes load -inf so they contribute exp(-inf) = 0 to the sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, 1)  # per-row max, for numerical stability
        e = tl.exp(inp - m[:, None])
        z = tl.sum(e, 1)
        out = inp - m[:, None] - tl.log(z)[:, None]
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online (streaming) softmax: sweep N in TILE_N chunks while keeping
        # a running max `m` and a running sum of rescaled exponentials `z`.
        m = tl.full([TILE_M, TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_M, TILE_N], value=0.0, dtype=tl.float32)

        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offset = m_offset[:, None] * N + n_offsets[None, :]
            mask = (m_offset[:, None] < M) & (n_offsets[None, :] < N)
            inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            # Skip the rescale when a lane has only ever seen -inf padding:
            # exp(m - m_new) would be exp(-inf - -inf) = exp(nan) otherwise.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the per-lane partial maxima/sums down to one value per row.
        m_reduced = tl.max(m, 1)
        z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
        m = m_reduced

        # Second pass: reload the input and emit inp - m - log(z).
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offset = m_offset[:, None] * N + n_offsets[None, :]
            mask = (m_offset[:, None] < M) & (n_offsets[None, :] < N)
            inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            out = inp - m[:, None] - tl.log(z)[:, None]
            tl.store(output_ptr + offset, out, mask=mask)

204 

205 

206# ------------------------ backward ------------------------------- 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def log_softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Backward of log-softmax over the middle axis of an (M, N, K) view:
    #   in_grad = out_grad - exp(out) * sum(out_grad over N)
    # Grid: axis 0 -> one program per M slice, axis 1 -> TILE_K chunk of K.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Fast path: all of N fits in one tile — one load/compute/store pass.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        # Masked lanes load 0, so they do not perturb the reduction.
        scale = tl.sum(out_grad_tile, axis=0)
        in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate sum(out_grad) along N in TILE_N steps.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_grad_tile
            offsets_n += TILE_N
            offsets += TILE_N * K
        # Collapse the per-lane partial sums down to one value per K column.
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: re-read both inputs and emit the gradient.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K

264 

265 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def log_softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Backward of log-softmax along the innermost axis of an (M, N) view:
    #   in_grad = out_grad - exp(out) * sum(out_grad over N)
    # Each program handles TILE_M consecutive rows.
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Fast path: an entire row fits in one tile — single pass.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        # Masked lanes load 0, so they do not perturb the row sum.
        scale = tl.sum(out_grad_tile, 1)
        in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate sum(out_grad) along N in TILE_N steps.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        # Collapse the per-lane partial sums down to one value per row.
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: re-read both inputs and emit the gradient.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            # `out` is streamed once and never reused; hint the cache to
            # evict it first.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N

321 

322 

def log_softmax(self, dim, half_to_float=False):
    """Compute log-softmax of `self` along dimension `dim`.

    Collapses the shape to (M, N, K) with N the reduction axis, then
    dispatches to the inner kernel (K == 1, contiguous reduction) or the
    non-inner kernel (K > 1).
    """
    logger.debug("GEMS_MTHREADS LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim

    N = self.shape[dim]
    M = 1  # product of the leading ("pre") dimensions
    for size in self.shape[:dim]:
        M *= size

    self = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # product of the trailing ("post") dimensions

    with torch_device_fn.device(self.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            log_softmax_kernel_non_inner[grid](
                out,
                self,
                M,
                N,
                K,
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            log_softmax_kernel_inner[grid](
                out,
                self,
                M,
                N,
            )
    return out

359 

360 

def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Backward pass of log-softmax along dimension `dim`.

    Computes in_grad = grad_output - exp(output) * sum(grad_output, dim),
    dispatching to the inner kernel (K == 1) or the non-inner kernel (K > 1)
    after collapsing the shape to (M, N, K) with N the reduction axis.

    Args:
        grad_output: gradient w.r.t. the log-softmax output.
        output: the forward log-softmax result.
        dim: reduction dimension (may be negative).
        input_dtype: dtype of the returned gradient tensor.

    Returns:
        Tensor of gradients w.r.t. the log-softmax input.
    """
    logger.debug("GEMS_MTHREADS LOG_SOFTMAX BACKWARD")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]  # pre_dim

    grad_output = grad_output.contiguous()
    # Fix: the kernels index both tensors with flat row-major offsets, so
    # `output` must be contiguous too (previously only grad_output was).
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N  # post_dim

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            log_softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            log_softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad