Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/softmax.py: 0%
180 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Child logger under the shared "flag_gems" root logger; lstrip(".") guards
# against a leading dot if __name__ is ever a relative module name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Device-side helper: round `a` up to the next multiple of `b`.
@triton.jit
def next_multiple_of(a, b):
    # the smallest x >= a such that x % b == 0
    return tl.cdiv(a, b) * b
# Device-side helper: largest multiple of `b` strictly below `a`.
# E.g. (10, 4) -> 8 and (8, 4) -> 4; used to split a row into the run of
# full tiles plus one tail tile.
@triton.jit
def prev_multiple_of(a, b):
    # the largest x < a such that x % b == 0
    return tl.cdiv(a, b) * b - b
# Softmax over the innermost (contiguous) dimension.
#
# Each program instance handles one row of a logically (M, N) input.  Two
# code paths, selected by the ONE_TILE_PER_CTA heuristic:
#   * the whole row fits in one TILE_N tile -> single-pass softmax;
#   * otherwise -> an online (streaming) softmax that keeps a running max
#     `m` and running normalizer `z`, then a second pass writes the output.
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # Fast path: one tile covers the whole row.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # Out-of-range lanes are filled with -inf so they contribute
        # exp(-inf) = 0 to the normalizer.  NOTE(review): the input is cast
        # to the OUTPUT dtype before the reduction, so max/exp/sum run in
        # the output precision here (not f32) — presumably intentional for
        # this backend; confirm.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Streaming path: per-lane running max `m` and normalizer `z`,
        # accumulated in float32.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Pass 1a: all full tiles, no masking needed.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input
            all_neg_inf = m_new == float("-inf")
            # Standard online-softmax rescale: z <- z*exp(m-m_new) + exp(inp-m_new);
            # skipped when everything seen so far is -inf (would be 0*inf = nan).
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration: masked tail tile
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the per-lane partials into row-level scalars.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        # Pass 2: write out exp(x - m) / z.  Tiles are emitted back-to-front
        # starting from the tail tile — presumably so the most recently read
        # data is written first; confirm against backend cache behavior.
        previous_multiple = prev_multiple_of(N, TILE_N)
        # specialize the first iteration: the masked tail tile
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining full tiles, walking backwards; no masking needed.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)
102# ------------------------ backward -------------------------------
def softmax_backward_kernel_inner_heur_tile_m(args):
    """Rows per CTA for the backward kernel: ceil(M / 12).

    12 is the device cluster count, so the fixed (12, 1, 1) launch grid
    covers all M rows.
    """
    # Equivalent to triton.cdiv(args["M"], 12); plain integer arithmetic
    # avoids pulling triton into a trivial host-side heuristic.
    return (args["M"] + 11) // 12
def softmax_backward_kernel_inner_heru_tile_n(args):
    """Columns per tile for the backward kernel: the row length N, capped at 4096."""
    # `min` is the plain builtin here; the original's local
    # `import builtins` was unnecessary (nothing shadows `min` in this module).
    return min(args["N"], 4096)
def softmax_backward_kernel_inner_heur_one_tile_per_cta(args):
    """True when a single TILE_N tile spans the whole row, so the kernel
    can take its no-loop fast path."""
    tile_n = args["TILE_N"]
    row_len = args["N"]
    return tile_n >= row_len
# Softmax backward (VJP) over the innermost dimension.
#
# For softmax output `y` and upstream gradient `dy`, the input gradient is
#   dx = y * (dy - sum(y * dy, axis=-1, keepdims=True)).
# Each program handles TILE_M rows.  When the row does not fit in one tile,
# two loops are used: one to accumulate scale = sum(y*dy), one to write dx.
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("softmax_inner"),
#     key=["M", "N"],
# )
# @triton.heuristics(
#     values=runtime.get_heuristic_config("softmax_backward_inner"),
# )
@triton.heuristics(
    values={
        "TILE_M": softmax_backward_kernel_inner_heur_tile_m,
        "TILE_N": softmax_backward_kernel_inner_heru_tile_n,
        "ONE_TILE_PER_CTA": softmax_backward_kernel_inner_heur_one_tile_per_cta,
    },
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Fast path: whole row in one tile; everything computed in float64.
        # NOTE(review): float64 accumulation is unusual for a GPU softmax —
        # presumably an accuracy requirement on this backend; confirm.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float64)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
        # scale[i] = sum_j y[i, j] * dy[i, j]
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate per-row scale = sum(y * dy) tile by tile.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float64)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_last: `out` is re-read in pass 2, keep it resident.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            ).to(tl.float64)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: dx = y * (dy - scale), re-walking the row.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # NOTE(review): here `out` is NOT cast to float64 before the
            # multiply (unlike pass 1 and the fast path) — the `.to` binds
            # to the parenthesized difference only.  Confirm this mixed
            # precision is intentional.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            )
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None]).to(tl.float64)
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N
def softmax(self, dim, half_to_float=False):
    """Compute softmax of `self` along `dim` with the Kunlunxin Triton kernel.

    The tensor is flattened to (M, N, K) with N = shape[dim]; when K > 1 the
    data is transposed so the reduced axis becomes innermost, letting the
    inner-dim kernel handle every case.

    Args:
        self: input tensor (made contiguous internally).
        dim: dimension to reduce over; negative values allowed.
        half_to_float: if True, compute and return in float32.

    Returns:
        Tensor of softmax values, same shape as the input.
    """
    logger.debug("GEMS SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]  # pre_dim: product of dims before `dim`
    self = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim: product of dims after `dim`

    with torch_device_fn.device(self.device):
        if K > 1:
            # grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            # Rearrange the input as [M, K, N] so the reduced axis is innermost.
            inp_view = self.view(M, N, K).transpose(1, 2).contiguous()
            # Fold M and K into a single leading dim M' = M * K.
            inp_reshaped = inp_view.view(M * K, N)
            # NOTE(review): m, n, k are only bound for ndim in {2, 3}; an
            # input with ndim > 3 and M > 1 would hit an unbound name in the
            # layout-restore branch below — confirm callers guarantee <= 3D
            # when K > 1.
            if out.ndim == 3:
                m, n, k = out.shape
            elif out.ndim == 2:
                m, n = out.shape
            origin_dim = out.ndim

            # Scratch output in the transposed [M*K, N] layout.
            out_view = out.view(M, N, K).transpose(1, 2).contiguous()
            out_reshaped = out_view.view(M * K, N)

            grid = lambda meta: (M * K, 1, 1)

            # Launch the forward Triton kernel; buffer_size_limit /
            # is_use_mask_zero are Kunlunxin-specific launch options —
            # presumably consumed by the backend launcher; confirm.
            softmax_kernel_inner[grid](
                out_reshaped,
                inp_reshaped,
                M * K,
                N,
                buffer_size_limit=2048,
                is_use_mask_zero=True,
            )

            # Restore the result to the original layout.
            # out_view.copy_(out_reshaped.view(M, K, N).transpose(1, 2))
            if M == 1 and origin_dim == 2:
                out = out_reshaped.view(K, N).transpose(0, 1)
            elif M == 1 and origin_dim == 3:
                out = out_reshaped.transpose(0, 1).view(m, n, k)
            else:
                out = out_reshaped.view(m, k, n).transpose(1, 2)
        else:
            # Reduced axis is already innermost: launch directly, one row
            # per program.
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](
                out,
                self,
                M,
                N,
                buffer_size_limit=2048,
                isCloseVectorization=True,
                is_use_mask_zero=True,
            )
    return out
def softmax_backward(grad_output, output, dim, input_dtype):
    """Vector-Jacobian product of softmax along `dim`.

    Mirrors `softmax`: flattens to (M, N, K) and, when K > 1, transposes so
    the reduced axis is innermost before calling the inner-dim backward
    kernel.

    Args:
        grad_output: upstream gradient, same shape as `output`.
        output: the forward softmax result.
        dim: dimension softmax was taken over; negative values allowed.
        input_dtype: dtype to cast the returned gradient to.

    Returns:
        Gradient w.r.t. the softmax input, cast to `input_dtype`.
    """
    logger.debug("GEMS SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    output = output.contiguous()
    # NOTE(review): gradients are accumulated in float64 and cast back at
    # the end — presumably an accuracy requirement on this backend; confirm
    # (float64 doubles memory traffic).
    in_grad = torch.empty_like(output, dtype=torch.float64)
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            # how to use softmax_backward_kernel_inner?
            # some transpose and continuous
            out_grad_view = grad_output.view(M, N, K).transpose(1, 2).contiguous()
            out_view = output.view(M, N, K).transpose(1, 2).contiguous()
            # Fold M and K into a single leading dim M' = M * K.
            out_grad_reshaped = out_grad_view.view(M * K, N)
            out_reshaped = out_view.view(M * K, N)
            # Scratch gradient buffer in the transposed [M*K, N] layout.
            in_grad_view = in_grad.view(M, N, K).transpose(1, 2).contiguous()
            in_grad_reshaped = in_grad_view.view(M * K, N)

            # 12 CTAs = device cluster count; TILE_M heuristic splits the
            # M*K rows across them.
            grid = lambda meta: (12, 1, 1)

            # Launch the backward Triton kernel; buffer_size_limit /
            # isCloseUnrollControl are Kunlunxin-specific launch options.
            softmax_backward_kernel_inner[grid](
                out_reshaped,
                out_grad_reshaped,
                in_grad_reshaped,
                M * K,
                N,
                buffer_size_limit=2048,
                isCloseUnrollControl=True,
            )
            # Restore the gradient to the original layout.
            # in_grad_view.copy_(in_grad_reshaped.view(M, K, N).transpose(1, 2))
            # NOTE(review): as in `softmax`, m/n/k are only bound for
            # ndim in {2, 3} — confirm callers guarantee <= 3D when K > 1.
            origin_dim = output.ndim
            if output.ndim == 3:
                m, n, k = output.shape
            elif output.ndim == 2:
                m, n = output.shape
            if M == 1 and origin_dim == 2:
                in_grad = in_grad_reshaped.view(K, N).transpose(0, 1)
            elif M == 1 and origin_dim == 3:
                in_grad = in_grad_reshaped.transpose(0, 1).view(m, n, k)
            else:
                in_grad = in_grad_reshaped.view(m, k, n).transpose(1, 2)
        else:
            # Reduced axis already innermost: launch directly.
            grid = lambda meta: (12, 1, 1)

            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                buffer_size_limit=2048,
                isCloseUnrollControl=True,
            )
    return in_grad.to(input_dtype)