Coverage for src/flag_gems/ops/log_softmax.py: 49%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-level logger named after this module; the entry points in this file
# emit a debug message on it each time they are called.
logger = logging.getLogger(__name__)

13 

14 

# Triton kernel: log_softmax along the middle axis of a contiguous (M, N, K)
# view of the input. Element (m, n, k) is addressed at m * N * K + n * K + k
# and the reduction runs over n. Program axis 0 selects a tile of BLOCK_M rows
# along M; program axis 1 selects the k index. Statistics are accumulated in
# float32. This kernel is not autotuned: BLOCK_M / BLOCK_N default to 8 / 256
# unless the caller passes other values.
# NOTE(review): offsets are built from tl.arange/program ids (presumably int32);
# tensors with more than 2**31 elements could overflow them — confirm.
@libentry()
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr = 8,
    BLOCK_N: tl.constexpr = 256,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # TODO(chenfeiyu): consider float64 and add a utility function to get accumulator type
    # Pass 1 (online softmax): every (row, lane) slot keeps a running max `m`
    # and a running sum `z` of exp(x - m) over the columns it has visited.
    m = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    z = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        # Out-of-range lanes read -inf, so they add nothing to the max or the sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m_new = tl.maximum(inp, m)
        # While a slot has only seen -inf, both exponents below would be
        # exp(-inf - (-inf)) = exp(NaN); keep z unchanged for those slots.
        all_neg_inf = m_new == float("-inf")
        # Rescale the old sum to the new max, then add the new term.
        z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
        m = m_new

    # Merge the BLOCK_N per-lane partials of each row into one (max, sum) pair.
    # NOTE(review): a row that is entirely -inf gives m_reduced = -inf, so
    # m - m_reduced is NaN and that row's output becomes NaN — confirm this is
    # the intended reference behavior.
    m_reduced = tl.max(m, 1)
    z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
    m = m_reduced

    # Pass 2: re-read the input and write x - max - log(sum(exp(x - max))).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        o = inp - m[:, None] - tl.log(z[:, None])
        tl.store(output_ptr + offset, o, mask=mask)

56 

57 

# Triton kernel: input gradient of log_softmax along the middle axis of
# contiguous (M, N, K) views of `out`, `out_grad` and `in_grad`. Element
# (m, n, k) of each tensor is addressed at m * N * K + n * K + k. Program axis
# 0 selects a tile of BLOCK_M rows along M; program axis 1 selects the k index.
# It computes
#     in_grad = out_grad - exp(out) * sum_n(out_grad)
# which is the log_softmax vector-Jacobian product, assuming `out` holds the
# forward log_softmax result (so exp(out) is the softmax). BLOCK_M / BLOCK_N
# come from the autotuner, keyed on M and N. Arithmetic is done in float32.
# NOTE(review): offsets are presumably int32; tensors with more than 2**31
# elements could overflow them — confirm.
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Pass 1: per-row sum of out_grad, accumulated lane-wise, then reduced.
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_grad_ptrs = out_grad_ptr + offsets
        # other=0.0: without it, masked-out lanes are not guaranteed to be zero
        # and would leak into the sum when N is not a multiple of BLOCK_N.
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Pass 2: in_grad = out_grad - softmax * row_sum. Masked lanes are never
    # stored, so their loaded values do not matter here.
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)

96 

97 

def log_softmax(self, dim, half_to_float=False):
    """Compute ``log_softmax`` of ``self`` along ``dim`` with a Triton kernel.

    The input is viewed as a contiguous (M, N, K) block: N is the size of
    ``dim``, M is the product of the leading sizes and K is the product of the
    trailing sizes. The kernel reduces over N.

    Args:
        self: input tensor. A 0-d tensor is treated as shape ``(1,)``, so it
            accepts ``dim`` in ``[-1, 0]``, as torch itself does for scalars.
        dim: dimension to normalize over; negative values count from the end.
        half_to_float: if True the result is allocated as float32; otherwise
            it has the dtype of ``self``.

    Returns:
        A new contiguous tensor with the same shape as ``self``.
    """
    logger.debug("GEMS LOG_SOFTMAX")

    shape = self.shape if self.ndim > 0 else torch.Size([1])
    ndim = len(shape)
    assert dim >= -ndim and dim < ndim, "Invalid dim"
    dim = dim % ndim
    # Derive M and K from shape products rather than numel() // M // N, which
    # raised ZeroDivisionError when M or N is zero.
    M = shape[:dim].numel()
    N = shape[dim]
    K = shape[dim + 1 :].numel()

    inp = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    # Empty tensors have nothing to compute; skip the launch, which would
    # otherwise need a zero-sized grid.
    if out.numel() == 0:
        return out

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            out,
            inp,
            M,
            N,
            K,
            num_warps=8,
        )
    return out

129 

130 

def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Compute the input gradient of ``log_softmax`` with a Triton kernel.

    All tensors are viewed as contiguous (M, N, K) blocks, where N is the size
    of ``dim``, M is the product of the leading sizes and K is the product of
    the trailing sizes.

    Args:
        grad_output: gradient with respect to the log_softmax output; assumed
            to have the same shape as ``output``.
        output: the forward log_softmax result (the kernel uses
            ``exp(output)`` as the softmax probabilities). A 0-d tensor is
            treated as shape ``(1,)``.
        dim: dimension the forward pass normalized over; may be negative.
        input_dtype: dtype of the returned gradient.

    Returns:
        A new contiguous tensor shaped like ``output`` with dtype
        ``input_dtype``.
    """
    logger.debug("GEMS LOG_SOFTMAX VJP")

    shape = output.shape if output.ndim > 0 else torch.Size([1])
    ndim = len(shape)
    assert dim >= -ndim and dim < ndim, "Invalid dim"
    dim = dim % ndim
    # Shape products avoid the ZeroDivisionError of numel() // M // N when
    # M or N is zero.
    M = shape[:dim].numel()
    N = shape[dim]
    K = shape[dim + 1 :].numel()

    # The kernel addresses all three tensors with the same contiguous offsets,
    # so `output` must be contiguous too. `in_grad` is allocated from the
    # contiguous copy, because empty_like would otherwise copy non-contiguous
    # strides that the kernel does not honor.
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    if in_grad.numel() == 0:
        return in_grad

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(in_grad.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
        )
    return in_grad