Coverage for src/flag_gems/ops/softmax.py: 33%

228 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.zeros import zero_ 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,  # number of leading rows (unused in the body; the grid covers it)
    N,  # size of the reduction (softmax) axis
    K,  # size of the trailing axes, K > 1 for the non-inner case
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax along a non-innermost axis: input is treated as a contiguous
    # (M, N, K) view and reduced over N. Grid axis 0 indexes M, axis 1 tiles K.
    pid_k = tle.program_id(1)
    pid_m = tle.program_id(0)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Whole reduction axis fits in one (TILE_N, TILE_K) tile: one-pass softmax.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # Out-of-bounds lanes read -inf so exp(-inf) = 0 and they do not
        # perturb max or sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        m = tl.max(inp, 0)  # per-k column max, for numerical stability
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax: first pass keeps a running max `m` and running
        # normalizer `z` per (n-lane, k) element, accumulated in float32.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            # If everything seen so far is -inf, exp(m - m_new) would be
            # ill-defined; keep z unchanged (still 0) in that case.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the TILE_N partial accumulators into one value per k column.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K,)
        m = m_reduced

        # specialization does not improve performance in this example, as tested
        # Second pass: recompute exp and normalize. Tiles are walked from the
        # last (possibly partial) tile backwards; prev_multiple_of(N, TILE_N)
        # is the start offset of that final tile.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - m[None, :]) / z[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)

74 

75 

@triton.jit
def next_multiple_of(a, b):
    # the smallest x >= a such that x % b == 0
    # BUG FIX: `tl.cidv` does not exist in triton.language; the ceiling
    # division helper is `tl.cdiv` (as used by prev_multiple_of below).
    return tl.cdiv(a, b) * b

80 

81 

@triton.jit
def prev_multiple_of(a, b):
    # the largest x < a such that x % b == 0
    # (assumes a, b > 0; when a is an exact multiple of b this yields a - b,
    # which is what the backward-walking store loops above rely on)
    return tl.cdiv(a, b) * b - b

86 

87 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,  # number of rows (unused in the body; the grid covers it)
    N,  # size of the innermost (reduction) axis
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax along the innermost, contiguous axis: input is treated as a
    # (M, N) view with one program per row.
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # The whole row fits in one tile: one-pass softmax.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # NOTE: this path computes in the *output* dtype (cast on load),
        # unlike the multi-tile path below which accumulates in float32.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)  # row max, for numerical stability
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax: running max `m` and running normalizer `z` in float32.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # prev_multiple_of(N, TILE_N) is the start of the final tile, so the
        # first loop covers only full tiles and needs no bounds mask.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration: the final tile may be partial, so it
        # is loaded with a mask and -inf padding.
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the TILE_N partial accumulators into the row max / sum.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        previous_multiple = prev_multiple_of(N, TILE_N)
        # Second pass walks tiles backwards from the last one.
        # specialize the first iteration: this loop runs exactly once and
        # handles the (possibly partial) final tile with a mask.
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                # data is consumed once here; hint the cache to drop it early
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining full tiles, walked backwards, no mask needed.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)

161 

162 

163# ------------------------ backward ------------------------------- 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,  # forward softmax output, viewed as (M, N, K)
    out_grad_ptr,  # incoming gradient, same shape
    in_grad_ptr,  # result: gradient w.r.t. the softmax input
    M,
    N,  # reduction axis of the forward softmax
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax VJP along a non-innermost axis:
    #   in_grad = out * (out_grad - sum_N(out * out_grad))
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Whole reduction axis in one tile: compute scale and gradient at once.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        # Masked lanes load 0 (default `other`), contributing nothing to scale.
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad over N in a (TILE_N, TILE_K) buffer.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_tile * out_grad_tile
            # advance both the logical n index (for the mask) and the pointers
            offsets_n += TILE_N
            offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: re-read the tiles and emit the gradient.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K

222 

223 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,  # forward softmax output, viewed as (M, N)
    out_grad_ptr,  # incoming gradient, same shape
    in_grad_ptr,  # result: gradient w.r.t. the softmax input
    M,
    N,  # innermost (reduction) axis of the forward softmax
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax VJP along the innermost axis, TILE_M rows per program:
    #   in_grad = out * (out_grad - sum_N(out * out_grad))
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Whole rows fit in one tile: compute scale and gradient at once.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        # Masked lanes load 0 (default `other`), contributing nothing to scale.
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad over N in a (TILE_M, TILE_N) buffer.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # `out` is read again in pass 2; hint the cache to keep it around
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: re-read the tiles and emit the gradient; `out` is now
        # consumed for the last time, so let it be evicted first.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N

282 

283 

def softmax(self, dim, half_to_float=False):
    """Softmax of `self` along `dim`, dispatching to a Triton kernel.

    The tensor is flattened to (M, N, K) with N = shape[dim]; K == 1 selects
    the inner-axis kernel, K > 1 the non-inner one. When `half_to_float` is
    true the result is produced in float32 regardless of the input dtype.
    """
    logger.debug("GEMS SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    # Degenerate case: empty tensor — nothing to normalize, return a
    # zero-filled tensor of the same shape.
    if self.numel() == 0:
        result = torch.empty(list(self.shape), dtype=self.dtype, device=self.device)
        zero_(result)
        return result

    dim = dim % self.ndim
    N = self.shape[dim]  # reduction size
    M = 1  # product of the dims before `dim`
    for size in self.shape[:dim]:
        M *= size
    self = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # product of the dims after `dim`

    with torch_device_fn.device(self.device):
        if K > 1:
            # non-inner reduction: one program per row, tiled over K
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](out, self, M, N, K)
        else:
            # innermost reduction: one program per row
            softmax_kernel_inner[(M, 1, 1)](out, self, M, N)
    return out

329 

330 

def softmax_backward(grad_output, output, dim, input_dtype):
    """VJP of softmax: given forward `output` and `grad_output`, return the
    gradient w.r.t. the softmax input (in `input_dtype`), reduced along `dim`.

    NOTE(review): only `grad_output` is made contiguous here; `output` is
    presumably already contiguous from the forward pass — confirm.
    """
    logger.debug("GEMS SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    N = output.shape[dim]  # reduction size
    M = 1  # product of the dims before `dim`
    for size in output.shape[:dim]:
        M *= size

    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N  # product of the dims after `dim`

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            # non-inner reduction: one program per row, tiled over K
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output, grad_output, in_grad, M, N, K
            )
        else:
            # innermost reduction: TILE_M rows per program
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](output, grad_output, in_grad, M, N)
    return in_grad