Coverage for src/flag_gems/runtime/backend/_metax/ops/log_softmax.py: 0%

98 statements

import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger("flag_gems." + __name__)


def heur_block_n(args):
    # One power-of-two block covers the whole reduction axis of length N.
    return triton.next_power_of_2(args["N"])


def heur_num_warps(args):
    # Scale the warp count with the row length N.
    if args["N"] <= 1024:
        return 1
    elif args["N"] <= 2048:
        return 4
    else:
        return 8
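

# Quick sanity check of the heuristics above (hypothetical argument values,
# shown for illustration only):
#   heur_block_n({"N": 1000})    # -> 1024: next power of two covering the row
#   heur_num_warps({"N": 1000})  # -> 1
#   heur_num_warps({"N": 2048})  # -> 4
#   heur_num_warps({"N": 4096})  # -> 8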


@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "num_warps": heur_num_warps,
    }
)
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_K: tl.constexpr,
):
    # The input is viewed as (M, N, K); each program reduces one BLOCK_M-tall
    # tile over the full N axis for a single k-slice.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offset = tl.arange(0, BLOCK_N)
    offset = m_offset[:, None] * N * K + n_offset[None, :] * K
    if USE_K:
        offset += pid_k
    mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
    input_ptrs = input_ptr + offset
    # Out-of-bounds lanes load -inf so they contribute exp(-inf) = 0 below.
    inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
    row_minus_max = inp - tl.max(inp, axis=1)[:, None]
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=1)[:, None]
    softmax_output = tl.log(numerator / denominator)
    output_ptrs = output_ptr + offset
    tl.store(output_ptrs, softmax_output, mask=mask)
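

# Note on the math above: with m = max_j x_j, the kernel computes
#   log_softmax(x)_i = log(exp(x_i - m) / sum_j exp(x_j - m)),
# which is algebraically identical to x_i - m - log(sum_j exp(x_j - m));
# subtracting the row max first keeps exp() from overflowing.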


@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "num_warps": heur_num_warps,
    }
)
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_N_SPLIT: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_split_offset = tl.arange(0, BLOCK_N_SPLIT)
    n_offset = tl.arange(0, BLOCK_N)
    all_offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
    out_grad_ptrs_all = out_grad_ptr + all_offsets
    all_mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
    # Row-wise sum of the incoming gradient; masked lanes load 0 and do not
    # affect the sum.
    out_grad_all = tl.load(out_grad_ptrs_all, mask=all_mask).to(tl.float32)
    scale = tl.sum(out_grad_all, 1)
    # Use a for loop to split the N dim into BLOCK_N_SPLIT-wide chunks,
    # reducing register pressure.
    for n in range(0, tl.cdiv(BLOCK_N, BLOCK_N_SPLIT)):
        offsets = (
            m_offset[:, None] * N * K
            + n_split_offset[None, :] * K
            + n * BLOCK_N_SPLIT * K
            + pid_k
        )
        mask = (m_offset[:, None] < M) & (
            n_split_offset[None, :] + n * BLOCK_N_SPLIT < N
        )
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        # `out` is already float32 after the cast on load, so no further cast
        # is needed; exp(log_softmax(x)) recovers softmax(x).
        exp_out = tl.exp(out)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)

        in_grad = out_grad - exp_out * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)
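

# Why in_grad = out_grad - exp(out) * scale: with y = log_softmax(x),
# dy_i/dx_j = delta_ij - softmax(x)_j, so
#   dL/dx_j = dL/dy_j - softmax(x)_j * sum_i dL/dy_i.
# Since out holds y, exp(out) recovers softmax(x), and `scale` is the row sum
# of the incoming gradient.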


def log_softmax(self, dim, half_to_float=False):
    logger.debug("METAX GEMS LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    # Collapse the shape to (M, N, K): M = product of dims before `dim`,
    # N = the softmax axis, K = product of dims after it.
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    inp = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N
    USE_K = K != 1

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            out,
            inp,
            M,
            N,
            K,
            USE_K=USE_K,
        )
    return out
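

# A minimal usage sketch for the wrapper above (illustrative, not part of this
# module; assumes an active device supported by this backend):
#   x = torch.randn(128, 1000, device="cuda")
#   y = log_softmax(x, dim=-1)
#   torch.testing.assert_close(
#       y, torch.log_softmax(x, dim=-1), rtol=1e-4, atol=1e-4
#   )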


def log_softmax_backward(grad_output, output, dim, input_dtype):
    logger.debug("METAX GEMS LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(in_grad.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
            BLOCK_N_SPLIT=1024,
        )
    return in_grad
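

# A minimal sketch of checking the backward wrapper against PyTorch autograd
# (illustrative only; assumes a device this backend supports):
#   x = torch.randn(64, 2048, device="cuda", requires_grad=True)
#   y_ref = torch.log_softmax(x, dim=-1)
#   g = torch.randn_like(y_ref)
#   (ref_grad,) = torch.autograd.grad(y_ref, x, g)
#   got = log_softmax_backward(g, y_ref.detach(), dim=-1, input_dtype=x.dtype)
#   torch.testing.assert_close(got, ref_grad, rtol=1e-4, atol=1e-4)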