Coverage for src/flag_gems/runtime/backend/_metax/ops/repeat_interleave.py: 0%

117 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5from triton import language as tl 

6 

7from flag_gems.utils import triton_lang_extension as tle 

8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

9from flag_gems.utils.shape_utils import c_contiguous_stride 

10from flag_gems.utils.tensor_wrapper import StridedBuffer 

11 

# Per-module logger, namespaced under the "flag_gems" root logger.
logger = logging.getLogger("flag_gems." + __name__)

13 

14 

@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    # Identity elementwise op. `pointwise_dynamic` turns it into a strided
    # copy kernel; it is instantiated below to expand a stride-0 input view
    # into a contiguous output (the actual "repeat" happens via strides).
    return x

19 

20 

def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Repeat every element of ``inp`` ``repeats`` times along ``dim``.

    Equivalent to ``torch.repeat_interleave(inp, repeats, dim)`` with an
    integer ``repeats``. When ``dim`` is None the input is flattened first
    and the repetition happens along the flat dimension.

    Args:
        inp: input tensor.
        repeats: non-negative int; repetition count for every element.
        dim: dimension to repeat along, or None to flatten first.
        output_size: optional expected size of the output along ``dim``;
            validated against the computed size.

    Returns:
        A new contiguous tensor of the same dtype/device as ``inp``.

    Raises:
        IndexError: if ``dim`` is out of range.
        RuntimeError: if ``repeats`` is negative, or ``output_size`` does not
            match the computed output size.
    """
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_SELF_INT")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # Reject negative counts explicitly (matches torch.repeat_interleave);
    # otherwise output_shape[dim] would silently go negative below.
    if repeats < 0:
        raise RuntimeError("repeats can not be negative")

    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    # repeats == 0: the output already has size 0 along `dim`; nothing to copy.
    if repeats == 0:
        return output

    # View the input with an extra size-`repeats`, stride-0 axis inserted
    # right after `dim`; a plain strided copy then reads each input element
    # `repeats` times while writing the contiguous output.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output

63 

64 

@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    # One program per element of `repeats`. Program `pid` writes the value
    # `pid` into out[cumsum[pid] - repeats[pid] : cumsum[pid]], i.e. index
    # `pid` repeated repeats[pid] times — producing the index vector that
    # the host-side wrapper feeds to torch.index_select.
    pid = tle.program_id(0)
    # Guard against a grid larger than `size`.
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    # cumsum is inclusive; subtracting this element's count gives the
    # exclusive prefix sum, i.e. the start of this pid's output segment.
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats can not be negative")

    out_ptr += out_offset
    # Fill the segment in BLOCK_SIZE-wide chunks.
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)

82 

83 

@triton.jit
def fused_repeat_and_index_select_kernel(
    inp, out, M, N, repeats, BLOCK_SIZE: tl.constexpr
):
    # One program per input row. `repeats` holds the *inclusive cumulative
    # sum* of per-row repeat counts, so input row `pid` maps to output rows
    # [cumsum[pid-1], cumsum[pid]). `M` is the row width in elements; `N`
    # is accepted but unused in the kernel body. Rows are addressed with
    # flat row-major offsets (pid * M + col), which assumes both `inp` and
    # `out` are contiguous — NOTE(review): confirm at call sites.
    pid = tle.program_id(0)
    row_idx_mask = pid > 0
    # cumsum[pid - 1], with 0 substituted for the first row.
    start_row_idx = tl.load(repeats + pid - 1, mask=row_idx_mask, other=0)
    end_row_idx = tl.load(repeats + pid)

    num_of_rows = end_row_idx - start_row_idx
    # This input row is repeated zero times; nothing to write.
    if num_of_rows == 0:
        return

    inp_row_offset = pid * M
    for m in range(0, tl.cdiv(M, BLOCK_SIZE)):
        cols_offsets = m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols_offsets < M
        cur_inp = tl.load(inp + inp_row_offset + cols_offsets, mask=mask, other=0.0)

        # Broadcast this chunk of the input row to every destination row.
        for cur_row_in_pid in range(0, num_of_rows):
            output_row_index = start_row_idx + cur_row_in_pid
            output_row_offsets = output_row_index * M + cols_offsets
            tl.store(out + output_row_offsets, cur_inp, mask=mask)

107 

108 

def repeat_interleave_tensor(repeats, *, output_size=None):
    """Expand a 1-D ``repeats`` vector into an index vector.

    Returns a 1-D tensor where index ``i`` appears ``repeats[i]`` times,
    i.e. ``torch.repeat_interleave(repeats)``.

    Args:
        repeats: 1-D integer tensor of non-negative counts.
        output_size: optional expected length of the result; validated
            against the computed length.

    Returns:
        1-D tensor of length ``repeats.sum()`` with ``repeats`` dtype/device.

    Raises:
        RuntimeError: if ``output_size`` does not match the computed length.
        AssertionError: if ``repeats`` is not 1-D or its sum is negative.
    """
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    # Empty repeats -> empty index vector. Also avoids the cumsum[-1]
    # IndexError below and launching a zero-sized grid.
    if repeats.numel() == 0:
        return torch.empty((0,), dtype=repeats.dtype, device=repeats.device)

    cumsum = repeats.cumsum(axis=0)
    result_size = cumsum[-1].item()

    assert result_size >= 0, "repeats can not be negative"

    # Validate output_size instead of silently ignoring it, consistent with
    # repeat_interleave_self_int above.
    if output_size is not None and output_size != result_size:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                result_size, output_size
            )
        )

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)
    size = repeats.size(0)

    # One program per repeats element; each writes its own output segment.
    grid = (size,)
    BLOCK_SIZE = 32
    repeat_interleave_tensor_kernel[grid](
        repeats,
        cumsum,
        out,
        size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=1,
    )
    return out

133 

134 

def fused_repeat_and_index_select(inp, repeats, dim):
    """Fused tensor-``repeats`` repeat_interleave for the 2-D, ``dim == 0`` case.

    Replicates row ``i`` of ``inp`` ``repeats[i]`` times without
    materializing an intermediate index vector.

    Args:
        inp: 2-D input tensor (rows along ``dim == 0``).
        repeats: 1-D tensor of non-negative per-row repeat counts,
            ``repeats.size(0) == inp.shape[dim]``.
        dim: dimension to repeat along (callers pass 0).

    Returns:
        Tensor shaped like ``inp`` with ``dim`` expanded to ``repeats.sum()``.
    """
    logger.debug("METAX GEMS FUSED_REPEAT_AND_INDEX_SELECT")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    # The kernel addresses rows with flat row-major offsets (pid * M + col);
    # a non-contiguous input (e.g. a transpose) would be read incorrectly.
    inp = inp.contiguous()

    # Empty repeats -> empty output. Also avoids repeats[-1] (IndexError)
    # and inp.numel() // N (ZeroDivisionError) below.
    if repeats.numel() == 0:
        out_shape = list(inp.shape)
        out_shape[dim] = 0
        return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Inclusive prefix sum: cumsum[i] is the end of row i's output range.
    repeats = repeats.cumsum(axis=0)
    index_len = repeats[-1].item()

    out_shape = list(inp.shape)
    out_shape[dim] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    N = inp.shape[dim]          # number of input rows (== grid size)
    M = inp.numel() // N        # row width in elements

    grid = (inp.shape[dim],)
    BLOCK_SIZE = min(triton.next_power_of_2(M), 4096)

    fused_repeat_and_index_select_kernel[grid](
        inp, out, M, N, repeats, BLOCK_SIZE=BLOCK_SIZE
    )
    return out

157 

158 

def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """Repeat slices of ``inp`` along ``dim`` per a tensor of counts.

    ``repeats[i]`` gives the repetition count of slice ``i`` along ``dim``;
    mirrors ``torch.repeat_interleave(inp, repeats_tensor, dim)``. A 0-dim or
    single-element ``repeats`` is treated as a scalar and delegated to
    :func:`repeat_interleave_self_int`.

    Args:
        inp: input tensor.
        repeats: 0-dim or 1-D tensor of non-negative repeat counts.
        dim: dimension to repeat along, or None to flatten first.
        output_size: optional expected output size along ``dim``
            (forwarded only on the scalar path).

    Returns:
        A new tensor with ``dim`` expanded to ``repeats.sum()``.

    Raises:
        IndexError: if ``dim`` is out of range.
        RuntimeError: if ``repeats`` has more than one dimension, or its
            length does not match ``inp.shape[dim]``.
    """
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # A scalar-like repeats reduces to the integer fast path.
    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        # Implicit string concatenation — the previous backslash continuation
        # embedded a run of source-indentation spaces in the message.
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # NOTE(review): output_size is not validated on the tensor path; the
    # helpers below do not receive it.
    if inp.ndim == 2 and dim == 0:
        # Fused kernel avoids materializing the index vector.
        res = fused_repeat_and_index_select(inp, repeats, dim)
    else:
        indices = repeat_interleave_tensor(repeats)
        res = torch.index_select(inp, dim, indices)

    return res