Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/softmax.py: 0%

186 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-30 03:43 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.zeros import zero_ 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module logger: a child of the shared "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

@triton.jit
def next_multiple_of(a, b):
    # Round a up to a multiple of b:
    # the smallest x >= a such that x % b == 0.
    num_tiles = tl.cdiv(a, b)
    return num_tiles * b

20 

21 

@triton.jit
def prev_multiple_of(a, b):
    # Round a down, exclusively, to a multiple of b:
    # the largest x < a such that x % b == 0 (for positive a).
    return (tl.cdiv(a, b) - 1) * b

26 

27 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,  # output rows, contiguous (M, N)
    input_ptr,  # input rows, contiguous (M, N)
    M,  # number of rows (grid is sized M; not read in the body)
    N,  # row length = softmax axis
    TILE_N: tl.constexpr,  # tile width along N
    ONE_TILE_PER_CTA: tl.constexpr,  # True when one tile covers the whole row
):
    # Row-wise softmax over contiguous rows of length N; one program per row.
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # Whole row fits in a single tile: plain max / exp / sum / divide.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # NOTE(review): the row is cast to the *output* dtype before the
        # reduction, so fp16/bf16 outputs reduce in low precision here
        # (unlike the multi-tile path, which accumulates in float32) —
        # confirm this is intentional on this backend.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)  # row max, for numerical stability
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Row longer than one tile: streaming ("online") softmax.
        # m: per-lane running max; z: per-lane running sum of exp(x - m),
        # both kept in float32.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Pass 1a: all full tiles (no masking needed).
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input; keep z
            # unchanged on those lanes so exp(-inf - -inf) cannot produce NaN
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Pass 1b: specialize the last, possibly partial, tile (masked load).
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the per-lane statistics to scalar row max / row sum.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        previous_multiple = prev_multiple_of(N, TILE_N)
        # Pass 2 writes out exp(x - m) / z, visiting the tail tile first and
        # then the full tiles in reverse order; each tile is read exactly
        # once, hence eviction_policy="evict_first".
        # Specialize the first iteration: the (possibly partial) tail tile,
        # with a mask. This loop runs exactly once.
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining full tiles, from previous_multiple - TILE_N down to 0;
        # no mask required.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)

101 

102 

103# ------------------------ backward ------------------------------- 

104 

105 

def softmax_backward_kernel_inner_heur_tile_m(args):
    """Heuristic for TILE_M: split the M rows evenly across the 12 clusters."""
    num_rows = args["M"]
    return triton.cdiv(num_rows, 12)  # 12 = cluster_num

109 

110 

def softmax_backward_kernel_inner_heru_tile_n(args):
    """Heuristic for TILE_N: cover the whole row, capped at 4096 columns."""
    n_cols = args["N"]
    return n_cols if n_cols < 4096 else 4096

116 

117 

def softmax_backward_kernel_inner_heur_one_tile_per_cta(args):
    """Heuristic: a single tile covers the row when TILE_N is at least N."""
    return not (args["TILE_N"] < args["N"])

120 

121 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("softmax_inner"),
#     key=["M", "N"],
# )
# @triton.heuristics(
#     values=runtime.get_heuristic_config("softmax_backward_inner"),
# )
@triton.heuristics(
    values={
        "TILE_M": softmax_backward_kernel_inner_heur_tile_m,
        "TILE_N": softmax_backward_kernel_inner_heru_tile_n,
        "ONE_TILE_PER_CTA": softmax_backward_kernel_inner_heur_one_tile_per_cta,
    },
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,  # forward softmax output, contiguous (M, N)
    out_grad_ptr,  # upstream gradient, contiguous (M, N)
    in_grad_ptr,  # result: gradient w.r.t. the softmax input, (M, N)
    M,  # number of rows
    N,  # row length = softmax axis
    TILE_M: tl.constexpr,  # rows handled per program
    TILE_N: tl.constexpr,  # columns per step
    ONE_TILE_PER_CTA: tl.constexpr,  # True when TILE_N >= N
):
    # Softmax VJP over the last axis:
    #   in_grad = out * (out_grad - sum(out * out_grad, axis=1))
    # Each program handles TILE_M consecutive rows; sums accumulate in
    # float64.
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # One tile covers the whole row: single load / reduce / store.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float64)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
        scale = tl.sum(out_tile * out_grad_tile, 1)  # per-row dot product
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate the per-row sum of out * out_grad tile by tile.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float64)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_last: out tiles are re-read in pass 2, keep them cached
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            ).to(tl.float64)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: in_grad = out * (out_grad - scale), tile by tile.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # NOTE(review): unlike pass 1, out_tile here is NOT cast to
            # float64 before the multiply (the .to applies to the
            # parenthesized difference) — confirm the mixed precision is
            # intentional.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            )
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None]).to(tl.float64)
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N

187 

188 

def softmax(self, dim, half_to_float=False):
    """Compute softmax of ``self`` along ``dim`` with the inner Triton kernel.

    Args:
        self: input tensor.
        dim: dimension to normalize over (may be negative).
        half_to_float: when True, compute and return in float32 regardless
            of the input dtype.

    Returns:
        Tensor with the same shape as ``self`` holding softmax(self, dim).
    """
    logger.debug("GEMS SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    # Special handling for an empty tensor: nothing to normalize.
    if self.numel() == 0:
        out_shape = list(self.shape)
        out = torch.empty(out_shape, dtype=self.dtype, device=self.device)
        zero_(out)
        return out

    dim = dim % self.ndim
    M = 1  # product of dims before `dim` (pre_dim)
    N = self.shape[dim]  # reduction length
    for i in range(dim):
        M *= self.shape[i]
    self = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # product of dims after `dim` (post_dim)

    with torch_device_fn.device(self.device):
        if K > 1:
            # The kernel reduces the innermost axis, so rearrange the input
            # from (M, N, K) to (M, K, N) and fold M and K into M' = M * K.
            inp_reshaped = (
                self.view(M, N, K).transpose(1, 2).contiguous().view(M * K, N)
            )

            origin_dim = out.ndim
            origin_shape = out.shape

            # Matching (M*K, N) view for the output buffer.
            out_reshaped = (
                out.view(M, N, K).transpose(1, 2).contiguous().view(M * K, N)
            )

            grid = (M * K, 1, 1)  # one program per row of the (M*K, N) view

            # NOTE(review): buffer_size_limit / is_use_mask_zero are
            # kunlunxin-specific launch options — semantics assumed from the
            # backend runtime, confirm there.
            softmax_kernel_inner[grid](
                out_reshaped,
                inp_reshaped,
                M * K,
                N,
                buffer_size_limit=2048,
                is_use_mask_zero=True,
            )

            # Restore the original (M, N, K) layout, avoiding copies where a
            # plain view/transpose suffices.
            if M == 1 and origin_dim == 2:
                out = out_reshaped.view(K, N).transpose(0, 1)
            elif M == 1 and origin_dim == 3:
                out = out_reshaped.transpose(0, 1).view(origin_shape)
            elif origin_dim == 3:
                m, n, k = origin_shape
                out = out_reshaped.view(m, k, n).transpose(1, 2)
            else:
                # General fallback (e.g. ndim > 3). Previously this path
                # raised NameError because m/k/n were only assigned for
                # 2-D/3-D outputs.
                out = out_reshaped.view(M, K, N).transpose(1, 2).reshape(origin_shape)
        else:
            # `dim` is the last axis: rows are already contiguous.
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](
                out,
                self,
                M,
                N,
                buffer_size_limit=2048,
                isCloseVectorization=True,
                is_use_mask_zero=True,
            )
    return out

263 

264 

def softmax_backward(grad_output, output, dim, input_dtype):
    """Softmax VJP: in_grad = output * (grad_output - sum(output * grad_output, dim)).

    Args:
        grad_output: upstream gradient, same shape as ``output``.
        output: forward softmax result.
        dim: dimension the forward softmax reduced over (may be negative).
        input_dtype: dtype to cast the returned gradient to.

    Returns:
        Gradient w.r.t. the softmax input, cast to ``input_dtype``.
    """
    logger.debug("GEMS SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1  # product of dims before `dim` (pre_dim)
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    output = output.contiguous()
    # Accumulate in float64; cast back to input_dtype on return.
    in_grad = torch.empty_like(output, dtype=torch.float64)
    K = output.numel() // M // N  # post_dim

    with torch_device_fn.device(in_grad.device):
        # 12 programs (one per cluster); the TILE_M heuristic is cdiv(M', 12).
        grid = (12, 1, 1)
        if K > 1:
            # Move the reduced axis innermost: (M, N, K) -> (M*K, N).
            out_grad_reshaped = (
                grad_output.view(M, N, K).transpose(1, 2).contiguous().view(M * K, N)
            )
            out_reshaped = (
                output.view(M, N, K).transpose(1, 2).contiguous().view(M * K, N)
            )
            # Matching (M*K, N) view for the gradient buffer.
            in_grad_reshaped = (
                in_grad.view(M, N, K).transpose(1, 2).contiguous().view(M * K, N)
            )

            # NOTE(review): buffer_size_limit / isCloseUnrollControl are
            # kunlunxin-specific launch options — confirm semantics against
            # the backend runtime.
            softmax_backward_kernel_inner[grid](
                out_reshaped,
                out_grad_reshaped,
                in_grad_reshaped,
                M * K,
                N,
                buffer_size_limit=2048,
                isCloseUnrollControl=True,
            )

            # Restore the original (M, N, K) layout.
            origin_dim = output.ndim
            origin_shape = output.shape
            if M == 1 and origin_dim == 2:
                in_grad = in_grad_reshaped.view(K, N).transpose(0, 1)
            elif M == 1 and origin_dim == 3:
                in_grad = in_grad_reshaped.transpose(0, 1).view(origin_shape)
            elif origin_dim == 3:
                m, n, k = origin_shape
                in_grad = in_grad_reshaped.view(m, k, n).transpose(1, 2)
            else:
                # General fallback (e.g. ndim > 3). Previously this path
                # raised NameError because m/k/n were only assigned for
                # 2-D/3-D outputs.
                in_grad = (
                    in_grad_reshaped.view(M, K, N).transpose(1, 2).reshape(origin_shape)
                )
        else:
            # `dim` is the last axis: rows are already contiguous.
            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                buffer_size_limit=2048,
                isCloseUnrollControl=True,
            )
    return in_grad.to(input_dtype)