Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/min.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9# from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13from flag_gems.utils.limits import get_dtype_max 

14 

15from ..utils.block_size_utils import get_block_size_1d 

16 

# Module logger as a child of the shared "flag_gems" logger; lstrip(".")
# removes any leading dots from __name__ before deriving the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

18 

19 

@libentry()
@triton.jit
def min_kernel_1(
    inp,  # pointer to the flattened input tensor
    mid,  # pointer to the partial-minimum buffer (one slot per program)
    M,  # total number of elements in `inp`
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the two-stage global min reduction: each program reduces one
    # BLOCK_SIZE-wide slice of the flattened input and stores its partial
    # minimum at mid[pid].
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Masked-off lanes load the dtype's maximum value so they never win the min.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_val = tl.load(inp_ptrs, mask=mask, other=max_value)
    min_val = tl.min(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, min_val)

37 

38 

@libentry()
@triton.jit
def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the two-stage global min reduction: a single program reduces
    # the `mid_size` partial minima produced by min_kernel_1 into the scalar
    # result at `out`. BLOCK_MID must be >= mid_size (next power of two).
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    # Padding lanes load the dtype's maximum value so they never win the min.
    max_value = get_dtype_max(mid.type.element_ty)
    mid_val = tl.load(mid_ptrs, mask=mask, other=max_value)
    min_val = tl.min(mid_val)
    tl.store(out, min_val)

49 

50 

def heur_m_block_size(args):
    """BLOCK_M heuristic: share the M rows across the 12 clusters and round
    the per-cluster portion up to a power of two."""
    per_cluster = triton.cdiv(args["M"], 12)  # 12 == cluster_num on this backend
    return triton.next_power_of_2(per_cluster)

53 

54 

def heur_n_block_size(args):
    """BLOCK_N heuristic: the reduced dimension N rounded up to a power of
    two, capped at 8192."""
    # builtins.min is used deliberately: this module defines its own `min`.
    import builtins

    candidate = triton.next_power_of_2(args["N"])
    return builtins.min(candidate, 8192)

59 

60 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("min"),
#     key=[
#         "M",
#         "N",
#     ],
# )
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def min_kernel(
    inp,  # pointer to a contiguous tensor viewed as (M, N, K)
    out_value,  # pointer to the (M, K) minima output
    out_index,  # pointer to the (M, K) int64 argmin output
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Dimensional min reduction: for each (m, k) pair, find the minimum over
    # the middle axis n and the index where it occurs.
    # Launch grid: (cdiv(M, BLOCK_M), K) — each program handles BLOCK_M rows
    # for a single k, iterating the N axis in BLOCK_N-wide tiles.
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    dtype = inp.type.element_ty
    # you just cannot create a function that return a tl.dtype in triton lang
    # (bfloat16 accumulates in float32; other dtypes accumulate natively)
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    # Running minima start at the dtype maximum so any real element wins.
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Flat offset into the contiguous (M, N, K) layout.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
        local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True)
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmin = tl.argmin(inp_vals, 1)
        # Merge this tile: keep the smaller value; indices are shifted by
        # start_n so they are absolute positions along the N axis.
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

    offset_index = m_offset * K + pid_k
    out_value_ptrs = out_value + offset_index
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_value_ptrs, min_values, mask=mask1)
    tl.store(out_index_ptrs, argmin_values, mask=mask1)

116 

117 

def min(inp):
    """Return the minimum of all elements of ``inp`` as a 0-d tensor.

    Two-stage reduction: min_kernel_1 produces one partial minimum per block,
    then min_kernel_2 reduces those partials to a scalar. When a single block
    covers the whole input, the partial buffer is returned directly.
    """
    logger.debug("GEMS MIN")
    import os

    M = inp.numel()
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            # One block covered everything; its partial min is the answer.
            return mid.reshape([])

        # Backend-specific simulator flags required by min_kernel_2 on XPU.
        # Snapshot any pre-existing values so we restore (not delete) them,
        # and restore in `finally` so a failed launch cannot leak the flags —
        # the previous code unconditionally deleted the variables afterwards,
        # clobbering externally-set values and leaking them on error.
        flag_names = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
        saved = {name: os.environ.get(name) for name in flag_names}
        os.environ["TRITONXPU_OTHER_SIM"] = "1"
        os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        try:
            min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
        finally:
            for name, previous in saved.items():
                if previous is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = previous
        return out

147 

148 

def min_dim(inp, dim=None, keepdim=False):
    """Minimum values and their indices along dimension ``dim``.

    Mirrors ``torch.min(input, dim)``: returns a ``namedtuple`` with fields
    ``values`` and ``indices``. ``keepdim`` keeps the reduced dimension with
    size one instead of squeezing it away.
    """
    logger.debug("GEMS MIN DIM")
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"

    dim = dim % inp.ndim
    shape = inp.shape
    # View the problem as a contiguous (M, N, K) tensor reduced over N.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    out_shape = list(shape)
    out_shape[dim] = 1
    out_value = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    def grid(meta):
        # One program per BLOCK_M rows, for each of the K trailing slices.
        return (triton.cdiv(M, meta["BLOCK_M"]), K)

    # Backend-specific workaround: disable core tiling for this exact
    # integer-typed problem size.
    isCloseCoreTiling = (
        inp.dtype in [torch.int16, torch.int32, torch.int64]
        and M == 4096
        and N == 256
    )
    with torch_device_fn.device(inp.device):
        min_kernel[grid](
            inp, out_value, out_index, M, N, K, isCloseCoreTiling=isCloseCoreTiling
        )
    Min_out = namedtuple("min", ["values", "indices"])
    return Min_out(values=out_value, indices=out_index)