Coverage for src/flag_gems/ops/softmax.py: 31%

222 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax along the N axis of a contiguous (M, N, K) view.

    Used when the reduced axis is NOT the innermost one (K > 1).  Each
    program handles one m-index and one tile of TILE_K columns; the launch
    grid is (M, cdiv(K, TILE_K)).  When ONE_TILE_PER_CTA is set, TILE_N
    covers the whole reduction axis and a single-pass softmax is used;
    otherwise the axis is streamed with an online (running-max) softmax.
    """
    pid_k = tle.program_id(1)
    pid_m = tle.program_id(0)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Whole reduction fits in one tile: classic max / exp / sum softmax.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # Out-of-bounds lanes read -inf so they contribute exp(-inf) = 0.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        m = tl.max(inp, 0)
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax: keep a running max (m) and a running sum of
        # exponentials rescaled to that max (z), per (n-lane, k-column).
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            # Guard against exp(-inf - (-inf)) = NaN when every value seen
            # so far in a lane is -inf (padding or genuine -inf input).
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the TILE_N partial accumulators into per-column scalars.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced

        # Second pass: reload, normalize, and store.  Tiles are visited
        # back-to-front (last tile first) — presumably to reuse the most
        # recently loaded tile while it is still cached; TODO confirm.
        # specialization does not improve performance in this example, as tested
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - m[None, :]) / z[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)

73 

74 

@triton.jit
def next_multiple_of(a, b):
    """Return the smallest x >= a such that x % b == 0."""
    # Bug fix: `tl.cidv` was a typo for `tl.cdiv` — triton.language has no
    # attribute `cidv`, so any kernel using this helper failed to compile.
    return tl.cdiv(a, b) * b

79 

80 

@triton.jit
def prev_multiple_of(a, b):
    """Return the largest x < a such that x % b == 0."""
    num_tiles = tl.cdiv(a, b)
    return (num_tiles - 1) * b

85 

86 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax along the last (innermost) axis of a contiguous (M, N) view.

    One program per row; launch grid is (M,).  When ONE_TILE_PER_CTA is
    set, TILE_N covers the whole row and a single pass suffices; otherwise
    the row is streamed tile by tile with an online (running-max) softmax
    followed by a second normalization pass.
    """
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # Cast to the output element type up front (e.g. float32 when the
        # caller requested half_to_float), so the reduction runs in that type.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax: running max (m) and max-rescaled running sum (z).
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        # Advance the base pointers to this program's row once.
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Full tiles first: no bounds mask needed below previous_multiple.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input;
            # skip the update to avoid exp(-inf - (-inf)) = NaN
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration: only the trailing, possibly
        # partial tile needs a bounds mask
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the TILE_N partial accumulators into row scalars.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        # Second pass, back to front: normalize and store each tile.
        previous_multiple = prev_multiple_of(N, TILE_N)
        # specialize the first iteration: the tile starting at
        # previous_multiple may be partial, so it is the only masked one
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining tiles are full — no mask needed.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)

160 

161 

162# ------------------------ backward ------------------------------- 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax VJP along the N axis of a contiguous (M, N, K) view.

    Computes in_grad = out * (out_grad - sum_n(out * out_grad)) for each
    (m, k) column, in float32.  Launch grid is (M, cdiv(K, TILE_K)).
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Whole reduction axis fits in one tile: single load/compute/store.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        # Per-column reduction sum_n(out * out_grad); masked lanes load 0.
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate sum_n(out * out_grad) per k-column, walking
        # the N axis by advancing offsets in place.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_tile * out_grad_tile
            offsets_n += TILE_N
            offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: reload both operands and write the input gradient.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K

221 

222 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax VJP along the last axis of a contiguous (M, N) view.

    Computes in_grad = out * (out_grad - sum_n(out * out_grad)) per row,
    in float32.  Each program handles TILE_M rows; launch grid is
    (cdiv(M, TILE_M),).
    """
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Whole row fits in one tile: single load/compute/store.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        # Per-row reduction sum_n(out * out_grad); masked lanes load 0.
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate sum_n(out * out_grad) per row.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_last: `out` is reloaded in pass 2, so keep it cached.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: reload both operands and write the input gradient.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_first: each tile is consumed exactly once in this pass.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N

281 

282 

def softmax(self, dim, half_to_float=False):
    """Compute softmax of `self` along dimension `dim`.

    Collapses the shape into (M, N, K) where N is the reduced axis, then
    dispatches to the inner kernel (K == 1, reduction over the innermost
    axis) or the non-inner kernel (K > 1).  When `half_to_float` is true
    the result is produced in float32; otherwise it keeps the input dtype.
    """
    logger.debug("GEMS SOFTMAX")

    ndim = self.ndim
    assert -ndim <= dim < ndim, "Invalid dim"
    dim %= ndim

    N = self.shape[dim]
    M = 1
    for size in self.shape[:dim]:
        M *= size  # pre_dim
    self = self.contiguous()
    K = self.numel() // (M * N)  # post_dim

    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(self, dtype=dtype)

    with torch_device_fn.device(self.device):
        if K > 1:
            # Reduction axis is not innermost: tile K across the grid too.
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](
                out,
                self,
                M,
                N,
                K,
            )
        else:
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](
                out,
                self,
                M,
                N,
            )
    return out

319 

320 

def softmax_backward(grad_output, output, dim, input_dtype):
    """Compute the softmax VJP: grad wrt the softmax input.

    `output` is the forward softmax result and `grad_output` the incoming
    gradient.  The shape is collapsed into (M, N, K) with N the reduced
    axis, and the inner (K == 1) or non-inner (K > 1) backward kernel is
    dispatched.  Returns the input gradient in `input_dtype`.
    """
    logger.debug("GEMS SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]  # pre_dim

    grad_output = grad_output.contiguous()
    # Bug fix: the kernels address both tensors with flat (M, N, K)
    # arithmetic, which is only valid for contiguous memory.  Only
    # grad_output was made contiguous before; a non-contiguous `output`
    # would be read with wrong strides.  contiguous() is a no-op when the
    # tensor is already contiguous, so this is free in the common case.
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N  # post_dim

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad