Coverage for src/flag_gems/runtime/backend/_ascend/ops/repeat_interleave.py: 0%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils.codegen_config_utils import CodeGenConfig 

7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

8from flag_gems.utils.shape_utils import c_contiguous_stride 

9from flag_gems.utils.tensor_wrapper import StridedBuffer 

10 

logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)


# Code-generation settings shared by the pointwise kernel defined below.
# The four positional values are passed in CodeGenConfig's declared field order
# (presumably tile size, max grid, warps per CTA, block-pointer preference —
# NOTE(review): confirm against flag_gems.utils.codegen_config_utils).
# 1-D tiling is preferred only when Triton's version string starts with a
# digit below 3.
config_ = CodeGenConfig(
    2048,
    (48, 1, 1),
    32,
    False,
    prefer_1d_tile=int(triton.__version__[0]) < 3,
)

21 

22 

# Identity scalar function. pointwise_dynamic turns it into an elementwise
# kernel with one input; in this file it is instantiated per rank and called
# with an explicit ``out0`` buffer, so it acts as a strided copy.
# The output dtype presumably follows input 0 under the "DEFAULT" promotion
# rule — NOTE(review): confirm against pointwise_dynamic's promotion semantics.
@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    return x

27 

28 

def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Repeat every element of ``inp`` ``repeats`` times along ``dim``.

    Args:
        inp: Input tensor.
        repeats: Number of repetitions per element; must be non-negative.
        dim: Dimension along which to repeat. ``None`` flattens ``inp`` first
            and repeats along dimension 0.
        output_size: Optional expected size of the result along ``dim``. If
            given, it must equal ``inp.shape[dim] * repeats``.

    Returns:
        A new tensor with ``inp``'s dtype and device. Its size along ``dim`` is
        multiplied by ``repeats``; all other sizes are unchanged.

    Raises:
        IndexError: If ``dim`` is outside ``[-inp.ndim, inp.ndim - 1]``.
        RuntimeError: If ``repeats`` is negative, or ``output_size`` does not
            match the computed output size.
    """
    logger.debug("GEMS_ASCEND REPEAT_INTERLEAVE_SELF_INT")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    elif dim < -inp.ndim or dim >= inp.ndim:
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -inp.ndim, inp.ndim - 1, dim
            )
        )

    # Reject negative counts up front; otherwise torch.empty would fail later
    # with an unrelated "negative dimension" error.
    if repeats < 0:
        raise RuntimeError(
            "repeat_interleave: repeats can not be negative, but got {}".format(
                repeats
            )
        )

    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())

    if dim < 0:
        dim += len(inp_shape)

    output_shape = list(inp_shape)
    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    # Nothing to copy when the result is empty (repeats == 0 or an empty input);
    # skip the kernel launch entirely.
    if output.numel() == 0:
        return output

    # Insert a new axis of length `repeats` right after `dim`.
    # - Input view: stride 0 on the new axis, so every r in [0, repeats) reads
    #   the same source element.
    # - Output view: C-contiguous, so index (..., i, r, ...) lands at position
    #   i * repeats + r along `dim` of `output`.
    # Copying input view -> output view therefore performs the interleaved repeat.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    copy_func.instantiate(len(out_view_shape))(in_view, out0=out_view)
    return output