Coverage for src/flag_gems/runtime/backend/_ascend/ops/log_softmax.py: 0%

99 statements  


import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')


@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # TODO(chenfeiyu): consider float64 and add a utility function to get the accumulator type
    # first pass: streaming (online) maximum m and rescaled exp-sum z over the reduced axis
    m = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    z = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m_new = tl.maximum(inp, m)
        all_neg_inf = m_new == float("-inf")
        z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
        m = m_new

    # reduce the per-column partial maxima and sums to one value per row
    m_reduced = tl.max(m, 1)
    z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
    m = m_reduced

    # second pass: log_softmax = inp - m - log(z)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        o = inp - m[:, None] - tl.log(z[:, None])
        tl.store(output_ptr + offset, o, mask=mask)

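The forward kernel is an online log-softmax: it walks the reduced axis in BLOCK_N tiles while carrying a running maximum m and a rescaled partial sum z, then writes inp - m - log(z). A minimal sketch of that recurrence on a single 1-D row in plain PyTorch (illustration only; the helper name and block size are made up and are not part of this module):

    import torch

    def streaming_log_softmax(x, block=4):
        # running maximum and rescaled exp-sum, as in the kernel above
        m = torch.tensor(float("-inf"))
        z = torch.tensor(0.0)
        for start in range(0, x.numel(), block):
            tile = x[start:start + block].float()
            m_new = torch.maximum(m, tile.max())
            if m_new == float("-inf"):
                # every element so far is -inf; skip the update, mirroring the all_neg_inf guard
                continue
            z = z * torch.exp(m - m_new) + torch.exp(tile - m_new).sum()
            m = m_new
        return x.float() - m - torch.log(z)

    # agrees with the reference log_softmax for an ordinary row
    x = torch.randn(10)
    assert torch.allclose(streaming_log_softmax(x), torch.log_softmax(x.float(), dim=0), atol=1e-6)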

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # first pass: row-wise sum of the upstream gradient
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # second pass: in_grad = out_grad - exp(out) * sum(out_grad), where exp(out) is the softmax
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)

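The backward kernel computes the standard log-softmax VJP: with y = log_softmax(x) and upstream gradient g, dx = g - exp(y) * sum(g) over the reduced axis. A small PyTorch check of that identity (illustrative only; the shapes are arbitrary):

    import torch

    x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
    y = torch.log_softmax(x, dim=1)
    g = torch.randn_like(y)

    (autograd_dx,) = torch.autograd.grad(y, x, grad_outputs=g)
    manual_dx = g - torch.exp(y) * g.sum(dim=1, keepdim=True)
    assert torch.allclose(autograd_dx, manual_dx)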

def log_softmax(self, dim, half_to_float=False):
    logger.debug("GEMS_ASCEND LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    inp = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            out,
            inp,
            M,
            N,
            K,
        )
    return out

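Both wrappers flatten the (contiguous) tensor around dim: M is the product of the leading dimensions, N the size of the reduced dimension, and K the product of the trailing ones, so element (m, n, k) lives at linear offset m*N*K + n*K + k, which is exactly the index arithmetic used in the kernels; the launch grid covers cdiv(M, BLOCK_M) programs along the first axis and K along the second. A tiny check of that layout (the shape is hypothetical, chosen only for illustration):

    import torch

    t = torch.arange(2 * 3 * 4).reshape(2, 3, 4)  # dim=1 -> M=2, N=3, K=4
    M, N, K = 2, 3, 4
    m, n, k = 1, 2, 3
    assert t.flatten()[m * N * K + n * K + k] == t[m, n, k]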

def log_softmax_backward(grad_output, output, dim, input_dtype):
    logger.debug("GEMS_ASCEND LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(in_grad.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
        )
    return in_grad
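
Assuming an environment where this Ascend backend is actually available (the device string and tolerances below are assumptions, not something this file defines), a round-trip through the two entry points could look like:

    import torch

    x = torch.randn(8, 128, device="npu", dtype=torch.float16)  # hypothetical device
    y = log_softmax(x, dim=-1)
    g = torch.randn_like(y)
    dx = log_softmax_backward(g, y, dim=-1, input_dtype=x.dtype)

    ref = torch.log_softmax(x.float(), dim=-1).to(x.dtype)
    assert torch.allclose(y, ref, atol=1e-3, rtol=1e-3)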