Coverage for src/flag_gems/runtime/backend/_metax/ops/log_softmax.py: 0% of 98 statements
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger("flag_gems." + __name__)


def heur_block_n(args):
    return triton.next_power_of_2(args["N"])


def heur_num_warps(args):
    if args["N"] <= 1024:
        return 1
    elif args["N"] <= 2048:
        return 4
    else:
        return 8
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "num_warps": heur_num_warps,
    }
)
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_K: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offset = tl.arange(0, BLOCK_N)
    offset = m_offset[:, None] * N * K + n_offset[None, :] * K
    if USE_K:
        offset += pid_k
    mask = m_offset[:, None] < M and n_offset[None, :] < N
    input_ptrs = input_ptr + offset
    inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
    row_minus_max = inp - tl.max(inp, axis=1)[:, None]
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=1)[:, None]
    softmax_output = tl.log(numerator / denominator)
    output_ptrs = output_ptr + offset
    tl.store(output_ptrs, softmax_output, mask=mask)
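# Reference sketch: for each (m, k) slice, the kernel above computes log(softmax(x))
# along the N axis, subtracting the row max before exponentiating for numerical
# stability. In plain PyTorch this is roughly (illustrative names, assuming a
# contiguous row-major view `x` of shape (M, N)):
#
#     ref = torch.log(torch.softmax(x.float(), dim=-1))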
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "num_warps": heur_num_warps,
    }
)
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_N_SPLIT: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_split_offset = tl.arange(0, BLOCK_N_SPLIT)
    n_offset = tl.arange(0, BLOCK_N)
    all_offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
    out_grad_ptrs_all = out_grad_ptr + all_offsets
    all_mask = m_offset[:, None] < M and n_offset[None, :] < N
    out_grad_all = tl.load(out_grad_ptrs_all, mask=all_mask).to(tl.float32)
    scale = tl.sum(out_grad_all, 1)
    # split the N dim with a for loop to reduce register cost
    for n in range(0, tl.cdiv(BLOCK_N, BLOCK_N_SPLIT)):
        offsets = (
            m_offset[:, None] * N * K
            + n_split_offset[None, :] * K
            + n * BLOCK_N_SPLIT * K
            + pid_k
        )
        mask = m_offset[:, None] < M and n_split_offset[None, :] + n * BLOCK_N_SPLIT < N
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        exp_out = tl.exp(out)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)

        in_grad = out_grad - exp_out * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)
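# Reference sketch: for y = log_softmax(x), the VJP computed above is
#
#     dx = dy - exp(y) * dy.sum(dim, keepdim=True)
#
# `scale` holds the per-row sum of dy; the loop then re-reads y and dy in chunks
# of BLOCK_N_SPLIT columns so the full row never has to stay resident in registers.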
def log_softmax(self, dim, half_to_float=False):
    logger.debug("METAX GEMS LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    inp = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N
    USE_K = K != 1

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            out,
            inp,
            M,
            N,
            K,
            USE_K=USE_K,
        )
    return out


def log_softmax_backward(grad_output, output, dim, input_dtype):
    logger.debug("METAX GEMS LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(in_grad.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
            BLOCK_N_SPLIT=1024,
        )
    return in_grad
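

# Minimal smoke-test sketch, assuming this module is importable with the flag_gems
# METAX backend installed and that the visible accelerator registers itself as a
# "cuda" device in torch (the device string is an assumption and may differ on
# MetaX hardware). It compares both wrappers against torch's reference ops.
if __name__ == "__main__":
    x = torch.randn(8, 4000, device="cuda", dtype=torch.float16, requires_grad=True)

    y = log_softmax(x, dim=-1)
    ref = torch.nn.functional.log_softmax(x, dim=-1)
    print("forward max abs err:", (y - ref).abs().max().item())

    dy = torch.randn_like(y)
    dx = log_softmax_backward(dy, y.detach(), dim=-1, input_dtype=x.dtype)
    (ref_dx,) = torch.autograd.grad(ref, x, grad_outputs=dy)
    print("backward max abs err:", (dx - ref_dx).abs().max().item())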