Coverage for src/flag_gems/ops/repeat_interleave.py: 74%

85 statements  

coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

import logging

import torch
import triton
from triton import language as tl

from flag_gems.utils import triton_lang_extension as tle
from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
from flag_gems.utils.shape_utils import c_contiguous_stride
from flag_gems.utils.tensor_wrapper import StridedBuffer

logger = logging.getLogger(__name__)


@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x
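
# A hedged reading of the helpers: `pointwise_dynamic` appears to turn the
# identity function into a generic strided elementwise copy, and
# `copy_func.instantiate(ndim)` (used below) specializes it for a given rank,
# so reading through a zero-stride input view lets a single copy broadcast
# each source element `repeats` times.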


def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_SELF_INT")
    if dim is None:
        # Match torch semantics: with no dim, repeat over the flattened input.
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )
    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    if repeats == 0:
        return output

    # Insert a zero-stride axis of length `repeats` after `dim` in the input
    # view, so each element is read `repeats` times; the output view has the
    # same shape with contiguous strides, and one elementwise copy
    # materializes the interleaved repetition.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output
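
# Minimal usage sketch (hypothetical `_demo_self_int` helper, not part of the
# original module; assumes a CUDA device): the result should match the
# PyTorch reference.
def _demo_self_int():
    inp = torch.arange(6, device="cuda").reshape(2, 3)
    out = repeat_interleave_self_int(inp, 2, dim=1)
    # Each column is repeated twice, so the shape grows from (2, 3) to (2, 6).
    assert torch.equal(out, torch.repeat_interleave(inp, 2, dim=1))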


@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    # One program per element of `repeats`: program `pid` fills its output
    # slice [cumsum[pid] - repeats[pid], cumsum[pid]) with the value `pid`.
    pid = tle.program_id(0)
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats cannot be negative")

    out_ptr += out_offset
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)
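
# Worked example: for repeats = [2, 0, 3] the cumsum is [2, 2, 5]. Program 0
# fills out[0:2] with 0, program 1 writes nothing, and program 2 fills
# out[2:5] with 2, yielding [0, 0, 2, 2, 2].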


def repeat_interleave_tensor(repeats, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accepts a 1D vector as repeats"

    # cumsum[-1] is the total output length; cumsum[i] - repeats[i] gives each
    # kernel program its starting offset.
    cumsum = repeats.cumsum(axis=0)
    result_size = cumsum[-1].item()

    assert result_size >= 0, "repeats cannot be negative"

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)
    size = repeats.size(0)

    grid = (size,)
    BLOCK_SIZE = 32
    repeat_interleave_tensor_kernel[grid](
        repeats,
        cumsum,
        out,
        size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=1,
    )
    return out
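
# Minimal usage sketch (hypothetical `_demo_tensor` helper, not part of the
# original module; assumes a CUDA device), matching the worked example above.
def _demo_tensor():
    repeats = torch.tensor([2, 0, 3], device="cuda")
    out = repeat_interleave_tensor(repeats)
    assert torch.equal(out, torch.tensor([0, 0, 2, 2, 2], device="cuda"))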


def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    logger.debug("GEMS REPEAT_INTERLEAVE_SELF_TENSOR")

    if repeats.numel() == 0:
        return inp.clone()

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # A 0-dim or single-element repeats tensor reduces to the scalar case.
    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be a 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # Expand indices (e.g. repeats=[1, 2] -> [0, 1, 1]) and gather along dim.
    indices = repeat_interleave_tensor(repeats)
    res = torch.index_select(inp, dim, indices)

    return res
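
# Minimal usage sketch (hypothetical `_demo_self_tensor` helper, not part of
# the original module; assumes a CUDA device): per-row counts, checked against
# the PyTorch reference.
def _demo_self_tensor():
    inp = torch.tensor([[1, 2], [3, 4]], device="cuda")
    repeats = torch.tensor([1, 2], device="cuda")
    out = repeat_interleave_self_tensor(inp, repeats, dim=0)
    # Rows: [1, 2], [3, 4], [3, 4].
    assert torch.equal(out, torch.repeat_interleave(inp, repeats, dim=0))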