# src/flag_gems/runtime/backend/_kunlunxin/ops/log_softmax.py
import builtins
import logging

import torch
import triton
import triton.language as tl

# from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_block_n(args):
    if args["N"] > 8192:
        return 64
    # tl.arange requires a power-of-two extent, so round N up rather than
    # using it verbatim.
    return triton.next_power_of_2(builtins.min(args["N"], 8192))


def heur_block_m(args):
    return triton.next_power_of_2(triton.cdiv(args["M"], 12))
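# Illustrative only: for a hypothetical launch with M=4096 rows and N=1000
# columns, the heuristics above yield BLOCK_M = next_power_of_2(cdiv(4096, 12))
# = 512 and BLOCK_N = next_power_of_2(1000) = 1024, so each program covers 512
# rows and scans the whole reduction axis in a single block iteration.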
@libentry()
# @triton.autotune(configs=runtime.get_triton_config("log_softmax"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # TODO(chenfeiyu): consider float64 and add a utility function to get the
    # accumulator type
    m = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    z = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    # First pass: online logsumexp. Stream over N in BLOCK_N chunks, keeping a
    # running maximum m and a partial sum z that is rescaled whenever m grows.
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m_new = tl.maximum(inp, m)
        # Guard fully-masked (all -inf) columns: exp(m - m_new) would be NaN.
        all_neg_inf = m_new == float("-inf")
        z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
        m = m_new

    # Reduce the per-column running stats into one logsumexp value per row.
    m_reduced = tl.max(m, 1)
    z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
    m = m_reduced

    # Second pass: log_softmax(x) = x - m - log(z).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        o = inp - m[:, None] - tl.log(z[:, None])
        tl.store(output_ptr + offset, o, mask=mask)
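# The kernel above computes logsumexp in one streaming pass instead of the
# textbook two passes (max, then sum). A minimal PyTorch sketch of the same
# m/z recurrence, for illustration only (this helper is hypothetical, is not
# used anywhere in this module, and omits the all -inf guard):
def _online_logsumexp_sketch(x: torch.Tensor, block: int = 4) -> torch.Tensor:
    m = torch.full(x.shape[:-1], float("-inf"), dtype=x.dtype, device=x.device)
    z = torch.zeros(x.shape[:-1], dtype=x.dtype, device=x.device)
    for start in range(0, x.shape[-1], block):
        chunk = x[..., start : start + block]
        m_new = torch.maximum(m, chunk.max(dim=-1).values)
        # Rescale the old partial sum to the new maximum before accumulating.
        z = z * torch.exp(m - m_new) + torch.exp(chunk - m_new[..., None]).sum(dim=-1)
        m = m_new
    return m + torch.log(z)  # matches torch.logsumexp(x, dim=-1)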
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # First pass: accumulate the sum of the incoming gradient per row.
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Second pass: in_grad = out_grad - softmax * sum(out_grad), with the
    # softmax recovered as exp(out) since out = log_softmax(input).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)
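# The VJP follows from y = x - logsumexp(x): dL/dx = dy - exp(y) * sum(dy)
# along the reduced dim. A hypothetical one-line reference (illustration only,
# not used by the kernel; reduces over the last dim):
def _log_softmax_vjp_sketch(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return dy - torch.exp(y) * dy.sum(dim=-1, keepdim=True)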
def log_softmax(self, dim, half_to_float=False):
    logger.debug("GEMS LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    # Collapse the input to an (M, N, K) view: N is the reduced dim, M the
    # product of the leading dims, and K the product of the trailing dims.
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    inp = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            out,
            inp,
            M,
            N,
            K,
            isCloseCoreTiling=True,
            num_warps=8,
        )
    return out
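# Example usage (illustration only; shapes are made up, and "cuda" stands in
# for whatever device string this kunlunxin backend maps to):
#
#   x = torch.randn(32, 128, device="cuda")
#   y = log_softmax(x, dim=-1)
#   torch.testing.assert_close(y, torch.nn.functional.log_softmax(x, dim=-1))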
def log_softmax_backward(grad_output, output, dim, input_dtype):
    logger.debug("GEMS LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(in_grad.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
            isCloseCoreTiling=True,
        )
    return in_grad
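# Hypothetical smoke test (not part of the original module): cross-checks the
# forward result and the VJP against PyTorch's reference implementation.
if __name__ == "__main__":
    x = torch.randn(8, 64, 4, requires_grad=True)  # move to the target device as needed
    ref = torch.nn.functional.log_softmax(x, dim=1)
    out = log_softmax(x.detach(), dim=1)
    torch.testing.assert_close(out, ref)

    dy = torch.randn_like(ref)
    (ref * dy).sum().backward()  # autograd's VJP for comparison
    din = log_softmax_backward(dy, out, dim=1, input_dtype=x.dtype)
    torch.testing.assert_close(din, x.grad)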