Coverage for src/flag_gems/runtime/backend/_ascend/ops/log_softmax.py: 0%
99 statements
coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')


@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # TODO(chenfeiyu): consider float64 and add a utility function to get the accumulator type
    # Pass 1: online (streaming) log-sum-exp over the reduced dimension. `m` is the
    # running per-lane max and `z` the running sum of exp(x - m), rescaled whenever
    # the max grows.
    m = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    z = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m_new = tl.maximum(inp, m)
        # Skip the rescale while every value seen so far is -inf (exp(-inf - -inf) is NaN).
        all_neg_inf = m_new == float("-inf")
        z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
        m = m_new

    # Reduce the per-lane accumulators to a single max and sum per row.
    m_reduced = tl.max(m, 1)
    z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
    m = m_reduced

    # Pass 2: write log_softmax(x) = x - max - log(sum(exp(x - max))).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        o = inp - m[:, None] - tl.log(z[:, None])
        tl.store(output_ptr + offset, o, mask=mask)
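

# A plain-PyTorch sketch of the streaming log-sum-exp that log_softmax_kernel
# implements, reduced to the 2-D case (rows kept, dim=1 reduced). Illustrative
# assumption only, not part of this module's API: the kernel keeps per-lane
# accumulators and reduces them once at the end, while this sketch keeps one
# running (m, z) pair per row for readability.
def _log_softmax_reference_2d(x: torch.Tensor, block_n: int = 128) -> torch.Tensor:
    m = torch.full((x.shape[0],), float("-inf"), dtype=torch.float32)
    z = torch.zeros(x.shape[0], dtype=torch.float32)
    for start in range(0, x.shape[1], block_n):
        chunk = x[:, start : start + block_n].float()
        m_new = torch.maximum(m, chunk.max(dim=1).values)
        # Rescale the running sum to the new max; keep z for rows that are still all -inf.
        rescaled = z * torch.exp(m - m_new) + torch.exp(chunk - m_new[:, None]).sum(dim=1)
        z = torch.where(m_new == float("-inf"), z, rescaled)
        m = m_new
    return (x.float() - m[:, None] - torch.log(z)[:, None]).to(x.dtype)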


@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Pass 1: accumulate the per-row sum of out_grad, the scale term of the VJP.
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Pass 2: in_grad = out_grad - exp(out) * sum(out_grad); exp(out) is softmax(x).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)
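

# A plain-PyTorch sketch of the formula the backward kernel evaluates: for
# y = log_softmax(x), dL/dx = dL/dy - exp(y) * sum(dL/dy) along the reduced
# dimension. Illustrative only; the helper name and signature are assumptions,
# not part of this module's API.
def _log_softmax_backward_reference(out_grad: torch.Tensor, out: torch.Tensor, dim: int) -> torch.Tensor:
    scale = out_grad.float().sum(dim=dim, keepdim=True)
    return (out_grad.float() - torch.exp(out.float()) * scale).to(out_grad.dtype)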


def log_softmax(self, dim, half_to_float=False):
    logger.debug("GEMS_ASCEND LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    # Collapse the input to an (M, N, K) view, where N is the reduced dimension.
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    inp = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            out,
            inp,
            M,
            N,
            K,
        )
    return out
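

# A minimal smoke-test sketch for the wrapper above. Assumptions: an Ascend/NPU
# device reachable as "npu" and the flag_gems Triton backend set up; the helper
# name is illustrative and not part of the module.
def _log_softmax_smoke_test():
    import torch.nn.functional as F

    x = torch.randn(4, 8, 16, device="npu", dtype=torch.float16)
    ref = F.log_softmax(x, dim=1, dtype=torch.float32)
    got = log_softmax(x, dim=1, half_to_float=True)
    torch.testing.assert_close(got, ref, rtol=1e-3, atol=1e-3)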


def log_softmax_backward(grad_output, output, dim, input_dtype):
    logger.debug("GEMS_ASCEND LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(in_grad.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
        )
    return in_grad
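

# A minimal consistency-check sketch for the backward wrapper, comparing it to
# gradients from torch autograd. Same device assumption as above; the helper
# name is illustrative only.
def _log_softmax_backward_smoke_test():
    x = torch.randn(4, 8, device="npu", dtype=torch.float32, requires_grad=True)
    y = torch.nn.functional.log_softmax(x, dim=1)
    cotangent = torch.randn_like(y)
    (ref,) = torch.autograd.grad(y, x, cotangent)
    got = log_softmax_backward(cotangent, y.detach(), dim=1, input_dtype=x.dtype)
    torch.testing.assert_close(got, ref, rtol=1e-4, atol=1e-4)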