Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/repeat_interleave.py: 0% (83 statements)


import logging

import torch
import triton
from triton import language as tl

from flag_gems.utils import triton_lang_extension as tle
from flag_gems.utils.shape_utils import c_contiguous_stride
from flag_gems.utils.tensor_wrapper import StridedBuffer

from ..utils.pointwise_dynamic import pointwise_dynamic

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    # Identity pointwise op: used below to copy a zero-stride input view into a
    # contiguous output buffer, realizing the repeat.
    return x


def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_SELF_INT")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )
    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    if repeats == 0:
        return output

    # View the input with an extra length-`repeats` axis of stride 0 inserted
    # after `dim`, so each element is read `repeats` times; the output is viewed
    # with the same shape but contiguous strides, and a plain copy realizes the
    # repeat.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output

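# Usage sketch (hedged, illustrative only; `x` and its values are hypothetical,
# and the call assumes a tensor on a device this backend serves). The result
# matches torch.repeat_interleave with an integer `repeats`:
#
#   x = torch.tensor([[1, 2], [3, 4]])            # placed on the target device
#   y = repeat_interleave_self_int(x, 2, dim=1)
#   # y == [[1, 1, 2, 2], [3, 3, 4, 4]]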

@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    # Program `pid` handles one element of `repeats`: it fills
    # out[cumsum[pid] - repeats[pid] : cumsum[pid]] with the value `pid`.
    pid = tle.program_id(0)
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats cannot be negative")

    out_ptr += out_offset
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)

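# Worked example (hedged, host-side description of the kernel's effect only):
# for repeats = [1, 2, 3], cumsum = [1, 3, 6]; program 0 writes 0 into
# out[0:1], program 1 writes 1 into out[1:3], and program 2 writes 2 into
# out[3:6], so out == [0, 1, 1, 2, 2, 2].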

def repeat_interleave_tensor(repeats, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accepts a 1D vector as repeats"

    # Note: `output_size` is accepted for API parity but is not validated here.
    cumsum = repeats.cumsum(axis=0)
    result_size = cumsum[-1].item()

    assert result_size >= 0, "repeats cannot be negative"

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)
    size = repeats.size(0)

    # One program per element of `repeats`; each writes its own index into its
    # slice of the output.
    grid = (size,)
    BLOCK_SIZE = 32
    repeat_interleave_tensor_kernel[grid](
        repeats,
        cumsum,
        out,
        size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=1,
    )
    return out

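# Usage sketch (hedged, illustrative only): mirrors torch.repeat_interleave
# called with a single tensor argument, where each index i appears repeats[i]
# times in the result:
#
#   r = torch.tensor([1, 2, 3])                   # placed on the target device
#   repeat_interleave_tensor(r)
#   # -> tensor([0, 1, 1, 2, 2, 2])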

def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # A 0-dim or single-element repeats tensor degenerates to the integer case.
    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # Expand repeats into gather indices (the [0, 1, 1, 2, 2, 2] pattern), then
    # gather along dim.
    indices = repeat_interleave_tensor(repeats)
    res = torch.index_select(inp, dim, indices)

    return res
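# Usage sketch (hedged, illustrative only; values are hypothetical). With a
# per-index repeats vector, the result matches torch.repeat_interleave(x, r, dim=0):
#
#   x = torch.tensor([[1, 2], [3, 4]])            # placed on the target device
#   r = torch.tensor([1, 2], device=x.device)
#   y = repeat_interleave_self_tensor(x, r, dim=0)
#   # y == [[1, 2], [3, 4], [3, 4]]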