Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/max.py: 0%

127 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2import math 

3import os 

4from collections import namedtuple 

5 

6import torch 

7import triton 

8import triton.language as tl 

9 

10# from flag_gems import runtime 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry 

13from flag_gems.utils import triton_lang_extension as tle 

14from flag_gems.utils.limits import get_dtype_min 

15 

16from ..utils.block_size_utils import get_block_size_1d 

17 

18logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

19 

20 

@libentry()
@triton.jit
def max_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full reduction.

    Each program reduces one BLOCK_SIZE-wide slice of the flat input ``inp``
    (``M`` elements in total) and writes that slice's maximum to ``mid[pid]``.
    """
    block_id = tle.program_id(0)
    lane_offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < M
    # Lanes past the end of the input are filled with get_dtype_min(...) so
    # that (assuming it yields the dtype's lowest value) they never win.
    fill_value = get_dtype_min(inp.type.element_ty)
    block_vals = tl.load(inp + lane_offsets, mask=in_bounds, other=fill_value)
    block_max = tl.max(block_vals)
    tl.store(mid + block_id, block_max)

38 

39 

@libentry()
@triton.jit
def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the full reduction.

    A single program loads the first ``mid_size`` partial maxima from ``mid``
    (padded up to BLOCK_MID lanes) and stores their maximum to ``out``.
    """
    lane_offsets = tl.arange(0, BLOCK_MID)
    in_bounds = lane_offsets < mid_size
    # Padding lanes are filled with get_dtype_min(...) so they cannot affect
    # the result (assuming it yields the dtype's lowest value).
    fill_value = get_dtype_min(mid.type.element_ty)
    partials = tl.load(mid + lane_offsets, mask=in_bounds, other=fill_value)
    overall_max = tl.max(partials)
    tl.store(out, overall_max)

50 

51 

def heur_m_block_size(args):
    """Return BLOCK_M: ceil(M / 12) rounded up to the next power of two."""
    cluster_num = 12
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    return triton.next_power_of_2(rows_per_cluster)

54 

55 

def heur_n_block_size(args):
    """Return BLOCK_N: N rounded up to a power of two, capped at 8192."""
    block_n_cap = 8192
    padded_n = triton.next_power_of_2(args["N"])
    return padded_n if padded_n <= block_n_cap else block_n_cap

60 

61 

62# def heur_m_block_size(args): 

63# # if triton.next_power_of_2(triton.cdiv(args["M"], cluster_num)) < core_num: 

64# # return triton.next_power_of_2(triton.cdiv(args["M"], cluster_num)) 

65# # else: 

66# return ( 

67# triton.cdiv(triton.cdiv(2048, args["ELEMENT_SIZE"]), args["N"]) 

68# * 64 

69# ) 

70 

71 

72# def heur_n_block_size(args): 

73# return min(args["N"], triton.cdiv(2048, args["ELEMENT_SIZE"])) 

74 

75 

@libentry()
# NOTE: autotuning is not used for this kernel; BLOCK_M / BLOCK_N are supplied
# by the heuristics below.
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def max_kernel(
    inp,
    out_value,
    out_index,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ELEMENT_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Max and argmax along the middle axis of an input addressed as (M, N, K).

    Element (m, n, k) is read from ``inp + m * N * K + n * K + k``. Program
    axis 0 selects a block of BLOCK_M rows ``m``; program axis 1 selects ``k``.
    For every (m, k) handled, the maximum over ``n`` goes to ``out_value`` and
    the position of that maximum goes to ``out_index``, both at ``m * K + k``.
    ELEMENT_SIZE is not read inside the kernel body.
    """
    row_block = tle.program_id(0)
    inner_idx = tle.program_id(1)
    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)

    elem_ty = inp.type.element_ty
    # bfloat16 inputs are accumulated in float32; other dtypes are kept as-is.
    acc_ty = tl.float32 if elem_ty is tl.bfloat16 else elem_ty
    fill_value = get_dtype_min(elem_ty)
    best_value = tl.full([BLOCK_M], value=fill_value, dtype=acc_ty)
    best_index = tl.zeros([BLOCK_M], dtype=tl.int64)
    for col_start in range(0, N, BLOCK_N):
        cols = col_start + tl.arange(0, BLOCK_N)
        flat_offset = rows[:, None] * N * K + cols[None, :] * K + inner_idx
        # Only lanes inside both the row range and the column range are loaded.
        valid = rows[:, None] < M and cols[None, :] < N
        tile = tl.load(inp + flat_offset, mask=valid, other=fill_value)
        tile_max, tile_argmax = tl.max(tile, axis=1, return_indices=True)
        # Strict '>' means a later tile replaces the running result only when
        # it is strictly larger, so ties keep the earlier index.
        improved = tile_max > best_value
        best_value = tl.where(improved, tile_max, best_value)
        best_index = tl.where(improved, col_start + tile_argmax, best_index)

    row_valid = rows < M
    out_offset = rows * K + inner_idx
    tl.store(out_value + out_offset, best_value, mask=row_valid)
    tl.store(out_index + out_offset, best_index, mask=row_valid)

130 

131 

def max(inp):
    """Return the maximum of all elements of ``inp`` as a 0-dim tensor.

    The reduction runs in two kernel launches: ``max_kernel_1`` writes one
    partial maximum per block into a temporary buffer, and ``max_kernel_2``
    reduces that buffer to a single value. Two shortcuts avoid work:

    * a one-element input is returned directly (reshaped to 0-dim);
    * if the first stage needs only one block, its single partial result is
      returned without launching the second stage.

    The TRITONXPU_* environment variables are set only around the launches
    that use them and are removed again on every exit path, including the
    early returns and the case where a launch raises.

    Args:
        inp: input tensor; it is made contiguous before reduction.

    Returns:
        A 0-dim tensor with the dtype of ``inp``.
    """
    logger.debug("GEMS MAX")
    inp = inp.contiguous()
    M = inp.numel()
    if M == 1:
        # A single element is its own maximum; nothing is launched, so the
        # process environment is left untouched.
        return inp.reshape([])

    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)

    os.environ["TRITONXPU_IS_SCATTER_SLICE"] = "1"
    try:
        with torch_device_fn.device(inp.device):
            # buffer_size_limit is a launch option passed through unchanged.
            max_kernel_1[(mid_size, 1, 1)](
                inp, mid, M, block_size, buffer_size_limit=2048
            )
            if mid_size == 1:
                return mid.reshape([])

            out = torch.empty([], dtype=dtype, device=inp.device)
            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            try:
                max_kernel_2[(1, 1, 1)](
                    mid, out, mid_size, block_mid, buffer_size_limit=2048
                )
            finally:
                os.environ.pop("TRITONXPU_OTHER_SIM", None)
                os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
    finally:
        # Runs for the early return above and for failed launches as well.
        os.environ.pop("TRITONXPU_IS_SCATTER_SLICE", None)
    return out

165 

166 

# Result type for max_dim, built once at import time instead of on every call.
_MaxDimResult = namedtuple("max", ["values", "indices"])


def max_dim(inp, dim=None, keepdim=False):
    """Return the maximum values and their indices along dimension ``dim``.

    The input is addressed as (M, N, K), where N is the size of ``dim``, M is
    the product of the sizes before it and K is the number of remaining
    elements per (M, N) pair. ``max_kernel`` is launched on a grid of
    (cdiv(M, BLOCK_M), K).

    The TRITONXPU_OTHER_SIM and TRITONXPU_STORE_MASK_SIM environment variables
    are set for the launch and removed afterwards, even if the launch raises.

    Args:
        inp: input tensor; it is made contiguous before the launch.
        dim: dimension to reduce, in ``[-inp.ndim, inp.ndim)``. Although the
            default is ``None``, an integer is required: the range check
            compares ``dim`` against integers.
        keepdim: if true, the reduced dimension is kept with size 1;
            otherwise it is squeezed out of both outputs.

    Returns:
        A namedtuple ``(values, indices)``; ``values`` has the dtype of
        ``inp`` and ``indices`` is int64.
    """
    logger.debug("GEMS MAX DIM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N
    ELEMENT_SIZE = inp.element_size()

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)

    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )

    # Launch option enabled only for this one integer-dtype shape.
    isCloseCoreTiling = (
        inp.dtype in [torch.int16, torch.int32, torch.int64]
        and M == 4096
        and N == 256
    )

    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
    try:
        with torch_device_fn.device(inp.device):
            max_kernel[grid](
                inp,
                out_value,
                out_index,
                M,
                N,
                K,
                ELEMENT_SIZE,
                isCloseCoreTiling=isCloseCoreTiling,
            )
    finally:
        # Always undo the environment changes, including on a failed launch.
        os.environ.pop("TRITONXPU_OTHER_SIM", None)
        os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)

    return _MaxDimResult(values=out_value, indices=out_index)

216 return out