Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/log_softmax.py: 0%

105 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import builtins 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module logger, namespaced as a child of the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

def heur_block_n(args):
    """Heuristic for BLOCK_N: tile long reduction axes with 64 columns,
    otherwise cover the whole axis (capped at 8192) in a single tile."""
    n = args["N"]
    return 64 if n > 8192 else builtins.min(n, 8192)

20 

21 

def heur_block_m(args):
    """Heuristic for BLOCK_M: split the M rows over roughly 12 programs and
    round the per-program row count up to the next power of two."""
    rows_per_program = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_program)

24 

25 

@libentry()
# @triton.autotune(configs=runtime.get_triton_config("log_softmax"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,  # rows: product of dims before the softmax dim
    N,  # length of the softmax (reduced) dim
    K,  # inner stride: product of dims after the softmax dim
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Numerically stable log-softmax along axis N of a contiguous (M, N, K)
    # view. Pass 1 streams over N keeping a running max `m` and a running
    # sum `z` of exp(x - m) (the "online softmax" recurrence); pass 2
    # re-reads the input and writes x - max - log(sum).
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # TODO(chenfeiyu): consider float64 and add a utility function to get accumulator type
    m = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    z = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset into the contiguous (M, N, K) layout.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        # Out-of-bounds lanes read -inf so they cannot affect the max/sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m_new = tl.maximum(inp, m)
        # Guard: if every value seen so far is -inf, exp(m - m_new) would be
        # exp(-inf - (-inf)) = NaN; keep z unchanged (still 0) instead.
        all_neg_inf = m_new == float("-inf")
        # Online rescale: old sum re-based to the new max, plus this tile.
        z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
        m = m_new

    # Collapse the per-lane partials across BLOCK_N into one max/sum per row.
    m_reduced = tl.max(m, 1)
    z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
    m = m_reduced

    # Pass 2: emit log_softmax(x) = x - max - log(sum(exp(x - max))).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        o = inp - m[:, None] - tl.log(z[:, None])
        tl.store(output_ptr + offset, o, mask=mask)

74 

75 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,  # forward output y = log_softmax(x)
    out_grad_ptr,  # incoming gradient dy
    in_grad_ptr,  # result: gradient dx
    M,  # rows: product of dims before the softmax dim
    N,  # length of the softmax (reduced) dim
    K,  # inner stride: product of dims after the softmax dim
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # log-softmax VJP: dx = dy - exp(y) * sum_N(dy).
    # Pass 1 accumulates sum_N(dy) per row; pass 2 applies the formula.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset into the contiguous (M, N, K) layout.
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_grad_ptrs = out_grad_ptr + offsets
        # Masked lanes load 0, so they add nothing to the sum.
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        scale += out_grad
    # Reduce the per-lane partials to one sum_N(dy) per row.
    scale = tl.sum(scale, 1)

    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        # exp(y) reconstructs softmax(x) from the stored log-softmax output.
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)

120 

121 

def log_softmax(self, dim, half_to_float=False):
    """Compute log_softmax of `self` along `dim`.

    Args:
        self: input tensor (any layout; a contiguous copy is taken).
        dim: dimension to reduce over; negative values are supported.
        half_to_float: when True, the result is produced in float32
            regardless of the input dtype.

    Returns:
        A contiguous tensor of the same shape holding log_softmax(self, dim).
    """
    logger.debug("GEMS LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim

    # Collapse the tensor to an (M, N, K) view: N is the softmax axis,
    # M the product of the leading dims, K the product of the trailing dims.
    N = self.shape[dim]
    M = 1
    for size in self.shape[:dim]:
        M *= size

    inp = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N

    # One program per (BLOCK_M rows, one K slice); BLOCK_M comes from the
    # kernel's heuristics, hence the meta-dependent grid callable.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]), K)

    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            out,
            inp,
            M,
            N,
            K,
            isCloseCoreTiling=True,
            num_warps=8,
        )
    return out

154 

155 

def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Backward of log_softmax: in_grad = grad_output - exp(output) * sum(grad_output, dim).

    Args:
        grad_output: gradient w.r.t. the log_softmax output.
        output: the forward log_softmax result.
        dim: dimension the softmax was taken over; negative values supported.
        input_dtype: dtype of the original forward input; the returned
            gradient is allocated with this dtype.

    Returns:
        Tensor with the same shape as `output` holding the input gradient.
    """
    logger.debug("GEMS LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # The kernel computes flat offsets assuming a contiguous (M, N, K)
    # layout, so BOTH tensors it reads must be contiguous. The original
    # code only normalized grad_output; a non-contiguous `output` would
    # have been indexed incorrectly (and `in_grad`, allocated via
    # empty_like, would have inherited its strides while being written
    # with contiguous offsets).
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(in_grad.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
            isCloseCoreTiling=True,
        )
    return in_grad