Coverage for src/flag_gems/runtime/backend/_cambricon/ops/repeat_interleave.py: 0%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5from triton import language as tl 

6 

7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

8from flag_gems.utils.shape_utils import c_contiguous_stride 

9from flag_gems.utils.tensor_wrapper import StridedBuffer 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    # Identity elementwise kernel. The pointwise_dynamic wrapper generates the
    # indexing for arbitrary strides, so this is used below to copy between a
    # broadcast (stride-0) input view and a contiguous output view.
    return x

18 

19 

def repeat_interleave_self_int_forward(inp, repeats, dim=None, *, output_size=None):
    """Repeat every slice of `inp` along `dim`, `repeats` times each (scalar repeats).

    Expects `dim` to already be a normalized, non-negative axis index (callers
    flatten / normalize before dispatching here).  Raises RuntimeError when an
    explicit `output_size` disagrees with the computed result extent.
    """
    shape = list(inp.shape)
    strides = list(inp.stride())

    result_shape = list(shape)
    result_shape[dim] *= repeats

    if output_size is not None and output_size != result_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                result_shape[dim], output_size
            )
        )

    result = torch.empty(result_shape, dtype=inp.dtype, device=inp.device)
    if repeats == 0:
        # Nothing to copy; the (empty along dim) output is already correct.
        return result

    # View the input with an extra broadcast axis of length `repeats` right
    # after `dim` (stride 0 so every repeat reads the same source element),
    # and the output as the matching contiguous layout.  Copying between the
    # two views materializes the interleaved repeats.
    view_shape = shape[: dim + 1] + [repeats] + shape[dim + 1 :]
    src_strides = strides[: dim + 1] + [0] + strides[dim + 1 :]
    dst_strides = c_contiguous_stride(view_shape)

    src = StridedBuffer(inp, view_shape, src_strides)
    dst = StridedBuffer(result, view_shape, dst_strides)
    copy_func.instantiate(len(view_shape))(src, out0=dst)
    return result

48 

49 

class RepeatInterleaveSelfIntFn(torch.autograd.Function):
    """Autograd wrapper for repeat_interleave with a scalar integer repeat count.

    Backward sums the gradient over the inserted repeat axis, which is the
    exact adjoint of interleaved repetition.
    """

    @staticmethod
    def forward(ctx, inp, repeats, dim, output_size):
        logger.debug("GEMS_CAMBRICON REPEAT_INTERLEAVE_SELF_INT FORWARD")
        # Save the caller-supplied (un-normalized) dim and the pre-flatten
        # shape; backward re-derives the reduction axis from these.
        ctx.inp_shape = inp.shape
        ctx.dim = dim
        ctx.repeats = repeats
        ctx.output_size = output_size

        if dim is None:
            # torch semantics: no dim means operate on the flattened input.
            inp = inp.flatten()
            dim = 0
        else:
            ndim = inp.ndim
            if not (-ndim <= dim < ndim):
                raise IndexError(
                    "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                        -ndim, ndim - 1, dim
                    )
                )
            if dim < 0:
                dim += ndim

        return repeat_interleave_self_int_forward(
            inp, repeats, dim=dim, output_size=output_size
        )

    @staticmethod
    def backward(ctx, grad_out):
        logger.debug("GEMS_CAMBRICON REPEAT_INTERLEAVE_SELF_INT BACKWARD")
        k = ctx.repeats
        shape = list(ctx.inp_shape)
        if ctx.dim is None:
            # Forward flattened the input, so in C order the k repeats of each
            # element sit contiguously: reshape to (*shape, k) and reduce last.
            reduce_axis = len(shape)
        else:
            d = ctx.dim if ctx.dim >= 0 else ctx.dim + len(shape)
            reduce_axis = d + 1
        shape.insert(reduce_axis, k)
        grad_x = grad_out.view(*shape).sum(dim=reduce_axis)
        # One gradient per forward arg: (inp, repeats, dim, output_size).
        return grad_x, None, None, None

95 

96 

def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Public entry point: repeat_interleave with a scalar integer `repeats`."""
    apply_fn = RepeatInterleaveSelfIntFn.apply
    # autograd.Function.apply takes positional args only.
    return apply_fn(inp, repeats, dim, output_size)

99 

100 

@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    # One program per element of `repeats`: program `pid` writes the value
    # `pid` into out[cumsum[pid] - repeats[pid] : cumsum[pid]].
    pid = tl.program_id(0)
    mask = pid < size  # guard in case the launch grid exceeds `size`
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    out_offset = cumsum - repeats  # start of this program's output segment

    tl.device_assert(repeats >= 0, "repeats can not be negative")

    out_ptr += out_offset
    # Write this program's `repeats` copies of `pid` in BLOCK_SIZE chunks.
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)

118 

119 

def repeat_interleave_tensor(repeats, *, output_size=None):
    """Expand a 1D `repeats` tensor into gather indices.

    Index value ``i`` appears ``repeats[i]`` times, e.g. ``[1, 2, 3]`` ->
    ``[0, 1, 1, 2, 2, 2]``.

    Fixes over the previous version:
    - an empty `repeats` tensor no longer raises IndexError from indexing
      ``cumsum[-1]``; it returns an empty tensor.
    - `output_size` was accepted but silently ignored; it is now validated
      against the computed total, consistent with the scalar-repeats path.
    """
    logger.debug("GEMS_CAMBRICON REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    size = repeats.size(0)
    if size == 0:
        # Avoid cumsum[-1] on an empty tensor.
        cumsum = None
        result_size = 0
    else:
        cumsum = repeats.cumsum(axis=0)
        result_size = cumsum[-1].item()

    assert result_size >= 0, "repeats can not be negative"

    if output_size is not None and output_size != result_size:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                result_size, output_size
            )
        )

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)
    if size == 0:
        return out

    # One program per repeats element; each writes its own output segment.
    grid = (size,)
    BLOCK_SIZE = 32
    repeat_interleave_tensor_kernel[grid](
        repeats,
        cumsum,
        out,
        size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=1,
    )
    return out

144 

145 

def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """repeat_interleave where `repeats` is a tensor (per-slice repeat counts).

    A 0-d or single-element `repeats` broadcasts to every slice and dispatches
    to the scalar path; otherwise `repeats` must be 1-D with one entry per
    slice of `inp` along `dim`.

    Fixes over the previous version:
    - the size-mismatch error message used a backslash line-continuation
      inside the string literal, which embedded a newline and raw source
      indentation into the runtime message; it is now a clean one-line message.
    - `output_size` is forwarded to repeat_interleave_tensor instead of being
      dropped on the tensor path.
    """
    logger.debug("GEMS_CAMBRICON REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        # torch semantics: no dim means operate on the flattened input.
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # A 0-d or 1-element repeats tensor acts as a scalar repeat count.
    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # Expand repeats into gather indices, then gather along dim.
    indices = repeat_interleave_tensor(repeats, output_size=output_size)
    res = torch.index_select(inp, dim, indices)

    return res