Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/softmax.py: 0%
186 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-30 03:43 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-30 03:43 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.zeros import zero_
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
# Module logger, nested under the package-wide "flag_gems" logger so backend
# log output follows the package's logging configuration.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def next_multiple_of(a, b):
    # The smallest x >= a such that x % b == 0 (round a up to a multiple of b).
    return tl.cdiv(a, b) * b
@triton.jit
def prev_multiple_of(a, b):
    # The largest x < a such that x % b == 0: cdiv rounds a up to the next
    # multiple of b, and subtracting b steps back to the one strictly below a.
    return tl.cdiv(a, b) * b - b
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Row-wise softmax over the innermost (contiguous) dim of an (M, N) view.

    One program handles one row (program id 0 indexes rows). When the whole
    row fits in a single tile (ONE_TILE_PER_CTA), a direct max/exp/sum pass
    is done in registers; otherwise a streaming ("online") max and rescaled
    exp-sum are accumulated per lane, reduced to row scalars, and a second
    sweep writes the normalized outputs.
    """
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # NOTE(review): the input is cast to the *output* dtype before the
        # reduction, so for fp16/bf16 outputs max/exp/sum run in reduced
        # precision on this path -- confirm this is intended for the backend.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax: per-lane running max `m` and running sum `z` of
        # exp(x - m); `z` is rescaled by exp(m_old - m_new) whenever m grows.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        previous_multiple = prev_multiple_of(N, TILE_N)
        # Full tiles: no bounds mask needed.
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input; keep z
            # unchanged in those lanes to avoid exp(-inf - (-inf)) = NaN
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration: the ragged tail tile needs a mask
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the per-lane partials into row scalars.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        previous_multiple = prev_multiple_of(N, TILE_N)
        # Second sweep writes outputs walking the row from the tail backwards
        # (offsets are previous_multiple - start_n), so recently-loaded tiles
        # are revisited first; evict_first hints they won't be reused.
        # specialize the first iteration: the ragged tail tile, masked
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining tiles are full; no mask required.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)
103# ------------------------ backward -------------------------------
def softmax_backward_kernel_inner_heur_tile_m(args):
    """Rows of the (M, N) problem handled per CTA.

    M is split evenly (rounding up) across the device's 12 clusters.
    """
    cluster_num = 12
    return triton.cdiv(args["M"], cluster_num)
def softmax_backward_kernel_inner_heru_tile_n(args):
    """Columns processed per tile along N: the whole row, capped at 4096."""
    return min(args["N"], 4096)
def softmax_backward_kernel_inner_heur_one_tile_per_cta(args):
    """True when one tile already spans the full row (TILE_N >= N)."""
    return not (args["TILE_N"] < args["N"])
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("softmax_inner"),
#     key=["M", "N"],
# )
# @triton.heuristics(
#     values=runtime.get_heuristic_config("softmax_backward_inner"),
# )
@triton.heuristics(
    values={
        "TILE_M": softmax_backward_kernel_inner_heur_tile_m,
        "TILE_N": softmax_backward_kernel_inner_heru_tile_n,
        "ONE_TILE_PER_CTA": softmax_backward_kernel_inner_heur_one_tile_per_cta,
    },
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax VJP over the innermost dim: in_grad = out * (out_grad - sum(out * out_grad)).

    Each program handles TILE_M consecutive rows of the (M, N) problem.
    Loads are accumulated in float64. When one tile covers a full row
    (ONE_TILE_PER_CTA), everything is done in a single pass; otherwise the
    row-wise scale sum(out * out_grad) is accumulated across N-tiles first,
    then a second pass over N writes the gradients.
    """
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float64)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
        # Per-row dot product of softmax output and upstream gradient.
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # First pass: accumulate sum(out * out_grad) across N-tiles.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float64)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_last: out is reloaded in the second pass, keep it cached.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            ).to(tl.float64)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Second pass: write in_grad = out * (out_grad - scale).
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            )
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float64)
            # NOTE(review): .to(tl.float64) binds to the parenthesized
            # difference only, and out_tile is NOT cast here (unlike the
            # one-tile branch) -- confirm the mixed-precision multiply is
            # intended rather than (out_tile.to(f64) * (...)).
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None]).to(tl.float64)
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N
def softmax(self, dim, half_to_float=False):
    """Softmax of `self` along `dim`, computed by the kunlunxin Triton kernel.

    The tensor is viewed as (M, N, K) with N the softmax dim, M the product
    of leading dims and K the product of trailing dims. K > 1 (softmax over a
    non-innermost dim) is handled by transposing to make N innermost and
    fusing M and K into a single row dimension.

    Args:
        self: input tensor (named `self` to match the aten-style signature).
        dim: dimension to normalize over; may be negative.
        half_to_float: if True, output is float32 regardless of input dtype.

    Returns:
        Tensor of the same shape as `self` (may be a non-contiguous view on
        the K > 1 path).
    """
    logger.debug("GEMS SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    # special handling for dim = 0 and empty tensor
    if self.numel() == 0:
        out_shape = list(self.shape)
        out = torch.empty(out_shape, dtype=self.dtype, device=self.device)
        zero_(out)
        return out

    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]  # pre_dim
    self = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim

    with torch_device_fn.device(self.device):
        if K > 1:
            # grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            # Rearrange the input data into [M, K, N] so N is innermost.
            inp_view = self.view(M, N, K).transpose(1, 2).contiguous()
            # Merge the M and K dims into M' = M * K.
            # NOTE(review): m/n/k below are only bound for 2-D/3-D inputs;
            # a >3-D tensor with K > 1 would hit a NameError in the final
            # view -- looks like only ndim <= 3 is supported on this path.
            inp_reshaped = inp_view.view(M * K, N)
            if out.ndim == 3:
                m, n, k = out.shape
            elif out.ndim == 2:
                m, n = out.shape
            origin_dim = out.ndim

            # Allocate a (M*K, N) view of the output for the kernel.
            out_view = out.view(M, N, K).transpose(1, 2).contiguous()
            out_reshaped = out_view.view(M * K, N)

            grid = lambda meta: (M * K, 1, 1)

            # Launch the Triton forward kernel; buffer_size_limit /
            # is_use_mask_zero are presumably kunlunxin-specific launch
            # options -- not standard Triton kwargs, verify against backend.
            softmax_kernel_inner[grid](
                out_reshaped,
                inp_reshaped,
                M * K,
                N,
                buffer_size_limit=2048,
                is_use_mask_zero=True,
            )

            # Restore the output to the original layout (as a strided view,
            # avoiding the copy below).
            # out_view.copy_(out_reshaped.view(M, K, N).transpose(1, 2))
            if M == 1 and origin_dim == 2:
                out = out_reshaped.view(K, N).transpose(0, 1)
            elif M == 1 and origin_dim == 3:
                out = out_reshaped.transpose(0, 1).view(m, n, k)
            else:
                out = out_reshaped.view(m, k, n).transpose(1, 2)
        else:
            # N is already innermost: one program per row of the (M, N) view.
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](
                out,
                self,
                M,
                N,
                buffer_size_limit=2048,
                isCloseVectorization=True,
                is_use_mask_zero=True,
            )
    return out
def softmax_backward(grad_output, output, dim, input_dtype):
    """VJP of softmax along `dim`: grad_in = out * (grad_out - sum(out * grad_out)).

    Mirrors `softmax`'s (M, N, K) decomposition: when K > 1 the tensors are
    transposed so N is innermost and M, K are fused into one row dimension.
    The gradient is accumulated in a float64 intermediate and cast to
    `input_dtype` on return.

    Args:
        grad_output: upstream gradient, same shape as `output`.
        output: the forward softmax result.
        dim: dimension the forward softmax was taken over; may be negative.
        input_dtype: dtype to cast the returned gradient to.

    Returns:
        Gradient w.r.t. the softmax input, dtype `input_dtype`.
    """
    logger.debug("GEMS SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    output = output.contiguous()
    # float64 scratch; downcast to input_dtype at the end.
    in_grad = torch.empty_like(output, dtype=torch.float64)
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            # how to use softmax_backward_kernel_inner?
            # some transpose and continuous
            out_grad_view = grad_output.view(M, N, K).transpose(1, 2).contiguous()
            out_view = output.view(M, N, K).transpose(1, 2).contiguous()
            # # Merge the M and K dims into M' = M * K.
            out_grad_reshaped = out_grad_view.view(M * K, N)
            out_reshaped = out_view.view(M * K, N)
            # Allocate a (M*K, N) view of the input gradient for the kernel.
            # NOTE(review): this is a contiguous *copy*, so the original
            # in_grad buffer is never written; in_grad is rebound to a view
            # of this copy below.
            in_grad_view = in_grad.view(M, N, K).transpose(1, 2).contiguous()
            in_grad_reshaped = in_grad_view.view(M * K, N)

            # 12 programs total; the TILE_M heuristic (cdiv(M, 12)) makes
            # them cover all M * K rows.
            grid = lambda meta: (12, 1, 1)

            # Launch the Triton backward kernel; buffer_size_limit /
            # isCloseUnrollControl are presumably kunlunxin-specific launch
            # options -- not standard Triton kwargs, verify against backend.
            softmax_backward_kernel_inner[grid](
                out_reshaped,
                out_grad_reshaped,
                in_grad_reshaped,
                M * K,
                N,
                buffer_size_limit=2048,
                isCloseUnrollControl=True,
            )
            # Restore the input gradient to the original layout (as a view).
            # in_grad_view.copy_(in_grad_reshaped.view(M, K, N).transpose(1, 2))
            # NOTE(review): as in softmax(), m/n/k are only bound for
            # 2-D/3-D inputs; >3-D with K > 1 would raise NameError below.
            origin_dim = output.ndim
            if output.ndim == 3:
                m, n, k = output.shape
            elif output.ndim == 2:
                m, n = output.shape
            if M == 1 and origin_dim == 2:
                in_grad = in_grad_reshaped.view(K, N).transpose(0, 1)
            elif M == 1 and origin_dim == 3:
                in_grad = in_grad_reshaped.transpose(0, 1).view(m, n, k)
            else:
                in_grad = in_grad_reshaped.view(m, k, n).transpose(1, 2)
        else:
            # N already innermost: kernel writes straight into in_grad.
            grid = lambda meta: (12, 1, 1)

            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                buffer_size_limit=2048,
                isCloseUnrollControl=True,
            )
    return in_grad.to(input_dtype)