Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/unique.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5from flag_gems.runtime import torch_device_fn 

6from flag_gems.utils.libentry import libentry 

7 

# Number of compute cores (multiprocessors) on the current device; used as the
# cap on the 1-D launch grid so each program handles a contiguous slice of work.
TOTAL_CORE_NUM = torch_device_fn.get_device_properties().multi_processor_count

9 

10 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**k}, num_stages=s, num_warps=1)
        for k in range(11, 17, 1)
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_ne_kernel(
    sorted_data_ptr: tl.tensor,  # in: values in ascending sorted order
    sorted_data_2: tl.tensor,    # in: same values shifted right by one slot (see caller)
    ne_out_ptr: tl.tensor,       # out: 1 where a new value's run begins, else 0
    tile_size: tl.constexpr,     # total number of elements
    BLOCK_SIZE: tl.constexpr,    # autotuned tile width per loop iteration
):
    # Flag the start of each run of equal values in a sorted array by comparing
    # every element against its predecessor (provided pre-shifted by the caller).
    # Each program covers a contiguous slice of ~tile_size / num_jobs elements.
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    split_n = (tile_size + num_jobs - 1) // num_jobs
    start_offset = pid * split_n
    i0 = tl.arange(0, BLOCK_SIZE)

    for i in range(0, split_n, BLOCK_SIZE):
        offset = start_offset + i + i0
        mask = offset < tile_size
        a = tl.load(sorted_data_ptr + offset, mask=mask)
        b = tl.load(sorted_data_2 + offset, mask=mask)
        # ne: element differs from its predecessor. The (offset > 0) factor
        # forces position 0 to 0, so the uninitialized slot 0 of the shifted
        # buffer can never leak into the result.
        ne_result = (offset > 0) * (a != b)
        tl.store(ne_out_ptr + offset, ne_result, mask=mask)

44 

45 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": k}, num_stages=s, num_warps=1)
        for k in [32, 256, 1024, 2048, 4096]
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_unique_out_kernel(
    sorted_data_ptr: tl.tensor,      # in: sorted values
    sorted_indices_ptr: tl.tensor,   # in: original positions of the sorted values
    ne_result_ptr: tl.tensor,        # in: run-boundary flags from get_ne_kernel
    pre_sum_ptr: tl.tensor,          # in: cumsum of the flags = output slot per element
    idx_ptr: tl.tensor,              # out: first sorted offset of each unique run (counts path)
    data_out_ptr: tl.tensor,         # out: compacted unique values
    inverse_indices_ptr: tl.tensor,  # out: original position -> unique slot (inverse path)
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    tile_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Scatter sorted values into their unique output slots, and optionally
    # build the inverse-index map and the per-run first-offset markers.
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)

    split_n = (tile_size + num_jobs - 1) // num_jobs
    start_offset = pid * split_n
    i0 = tl.arange(0, BLOCK_SIZE)

    for i in range(0, split_n, BLOCK_SIZE):
        offset = start_offset + i + i0
        mask = offset < tile_size
        sorted_data = tl.load(sorted_data_ptr + offset, mask=mask)
        pre_sum_data = tl.load(pre_sum_ptr + offset, mask=mask)

        # data_out: scatter_(to=pre_sum_data, sorted_data). Duplicates in a run
        # all write the same value to the same slot, so the race is benign.
        tl.store(data_out_ptr + pre_sum_data, sorted_data, mask=mask)

        # inverse_indices: scatter_(to=sorted_indices, pre_sum_data) — map each
        # element's ORIGINAL position to the unique slot it landed in.
        if return_inverse:
            sorted_indices = tl.load(sorted_indices_ptr + offset, mask=mask)
            tl.store(inverse_indices_ptr + sorted_indices, pre_sum_data, mask=mask)

        # idx: record the first sorted offset of each run (offset 0, or any
        # position whose boundary flag is set); consumed later to derive counts.
        if return_counts:
            ne_result = tl.load(ne_result_ptr + offset, mask=mask)
            idx_mask = ((offset == 0) | ne_result.to(tl.int1)) & mask
            tl.store(idx_ptr + pre_sum_data, offset, mask=idx_mask)

97 

98 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**k}, num_stages=s, num_warps=1)
        for k in range(7, 14, 1)
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_output_counts_kernel(
    idx_ptr: tl.tensor,       # in: first sorted offset of each unique run
    idx_next_ptr: tl.tensor,  # in: idx shifted left by one (last entry = total count)
    counts_ptr: tl.tensor,    # out: per-unique-value occurrence counts
    tile_size: tl.constexpr,  # number of unique values
    BLOCK_SIZE: tl.constexpr,
):
    # counts[i] = idx_next[i] - idx[i]: each run's length is the distance
    # between consecutive run-start offsets.
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    split_n = (tile_size + num_jobs - 1) // num_jobs
    start_offset = pid * split_n

    i0 = tl.arange(0, BLOCK_SIZE)

    for i in range(0, split_n, BLOCK_SIZE):
        offset = start_offset + i + i0
        mask = offset < tile_size
        # load idx
        idx = tl.load(idx_ptr + offset, mask=mask)
        # load idx_next
        idx_next = tl.load(idx_next_ptr + offset, mask=mask)
        # diff
        counts = idx_next - idx
        # store counts
        tl.store(counts_ptr + offset, counts, mask=mask)

135 

136 

def sorted_unique_flat(
    sorted_data: torch.Tensor,
    sorted_indices: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Compact a pre-sorted 1-D tensor into its unique values.

    Args:
        sorted_data: 1-D tensor of values in ascending order.
        sorted_indices: original positions of ``sorted_data``'s elements
            (second output of ``torch.sort``).
        return_inverse: also produce the original-position -> unique-slot map.
        return_counts: also produce per-unique-value occurrence counts.

    Returns:
        Tuple ``(unique_values, inverse_indices_or_None, counts_or_None)``;
        ``inverse_indices`` is flat with the same numel as the input.
    """
    num_tasks = sorted_data.numel()

    # Fix: empty input used to crash below (indexing pre_sum[-1] on a
    # 0-element tensor). Short-circuit with empty outputs of the right dtypes.
    if num_tasks == 0:
        inverse_indices = (
            torch.empty_like(sorted_data, dtype=torch.int64)
            if return_inverse
            else None
        )
        counts = (
            torch.empty_like(sorted_data, dtype=torch.int64)
            if return_counts
            else None
        )
        return torch.empty_like(sorted_data), inverse_indices, counts

    grid = lambda meta: (
        min(triton.cdiv(num_tasks, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
    )

    # Allocate outputs / scratch buffers.
    ne_out = torch.empty_like(sorted_data, dtype=torch.bool)
    data_out = torch.empty_like(sorted_data)
    if return_inverse:
        inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
    else:
        inverse_indices = None
    if return_counts:
        idx = torch.empty_like(sorted_data, dtype=torch.int64)
    else:
        idx = None
    # sorted_data shifted right by one; slot 0 stays uninitialized, but the ne
    # kernel forces position 0 to False so it is never read meaningfully.
    sorted_data_2 = torch.empty_like(sorted_data)
    sorted_data_2[1:] = sorted_data[:-1]

    # Launch: flag run boundaries, prefix-sum them into output slots, then
    # scatter values (and optionally inverse indices / run-start markers).
    with torch_device_fn.device(sorted_data.device.index):
        get_ne_kernel[grid](
            sorted_data,
            sorted_data_2,
            ne_out,
            tile_size=num_tasks,
        )
        pre_sum = ne_out.cumsum(axis=0)
        get_unique_out_kernel[grid](
            sorted_data,
            sorted_indices,
            ne_out,
            pre_sum,
            idx,
            data_out,
            inverse_indices,
            return_inverse,
            return_counts,
            tile_size=num_tasks,
        )

    # Unique count = last prefix-sum entry + 1 (slot 0 is always occupied).
    out_size = pre_sum[-1].item() + 1
    counts = None
    if return_counts:
        idx = idx[:out_size]
        sorted_data_size = len(sorted_data)
        # counts[i] = start of run i+1 minus start of run i; the final run
        # ends at the total element count.
        idx_next = torch.roll(idx, -1)
        idx_next[-1] = sorted_data_size
        counts = torch.zeros_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            get_output_counts_kernel[grid](
                idx,
                idx_next,
                counts,  # out
                tile_size=out_size,
            )
    return data_out[:out_size], inverse_indices, counts

200 

201 

def _unique2(
    in0: torch.Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """torch.unique-style entry point over an arbitrarily-shaped tensor.

    Flattens ``in0``, sorts it, and delegates the compaction to
    ``sorted_unique_flat``. The ``sorted`` flag mirrors the ATen signature;
    output is produced in sorted order either way.

    Returns:
        ``(unique_values, inverse_indices_or_None, counts_or_None)``;
        ``inverse_indices``, when requested, is reshaped to ``in0``'s shape.
    """
    flat = in0.ravel()
    sorted_data, sorted_indices = torch.sort(flat, stable=False)
    data_out, inverse_indices, counts = sorted_unique_flat(
        sorted_data, sorted_indices, return_inverse, return_counts
    )
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(in0)
    return data_out, inverse_indices, counts