Coverage for src/flag_gems/runtime/backend/_cambricon/ops/slice_scatter.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping 

8 

9logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

10 

11 

@triton.jit
def slice_scatter_kernel(
    out_ptr,
    inp_ptr,
    src_ptr,
    total_elements,
    dim_size,
    dim_prod_post,
    start,
    step,
    src_dim_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ``inp`` to ``out``, substituting ``src`` on the strided slice.

    Each program handles ``BLOCK_SIZE`` consecutive flat offsets. All three
    buffers are addressed with plain linear offsets, so the flat offset is
    decomposed as ``(pre, dim, post)`` using ``dim_size`` and
    ``dim_prod_post``; ``src`` is addressed with the same decomposition but
    with ``src_dim_size`` entries along the sliced dimension.

    Args:
        out_ptr: Destination buffer, ``total_elements`` long.
        inp_ptr: Source of the elements that lie outside the slice.
        src_ptr: Source of the elements that lie inside the slice.
        total_elements: Number of elements in ``inp`` / ``out``.
        dim_size: Extent of ``inp`` along the sliced dimension.
        dim_prod_post: Product of the extents after the sliced dimension.
        start: First index of the slice along the sliced dimension.
        step: Stride of the slice along the sliced dimension.
        src_dim_size: Extent of ``src`` along the sliced dimension.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Flat offsets covered by this program, and which of them are in range.
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < total_elements

    # Split the flat offset into (pre, dim, post) coordinates.
    outer = idx // (dim_size * dim_prod_post)
    along_dim = (idx // dim_prod_post) % dim_size
    inner = idx % dim_prod_post

    # An element belongs to the slice when its index along the dimension is
    # start + k * step for some 0 <= k < src_dim_size.
    slice_stop = start + src_dim_size * step
    in_slice = (
        (along_dim >= start)
        & (along_dim < slice_stop)
        & ((along_dim - start) % step == 0)
    )

    # Matching flat offset inside src (only meaningful where in_slice holds,
    # which is why the load below is masked with it).
    src_along_dim = (along_dim - start) // step
    src_offset = (
        outer * src_dim_size * dim_prod_post
        + src_along_dim * dim_prod_post
        + inner
    )

    from_inp = tl.load(inp_ptr + idx, mask=in_bounds)
    from_src = tl.load(src_ptr + src_offset, mask=in_bounds & in_slice)

    merged = tl.where(in_slice, from_src, from_inp)
    tl.store(out_ptr + idx, merged, mask=in_bounds)

50 

51 

def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
    """Return a copy of ``inp`` with ``src`` embedded at ``start:end:step`` on ``dim``.

    The slice bounds are normalised the same way Python slicing does it:
    ``None`` means "from the beginning" / "to the end", negative values are
    counted from the end of the dimension, and out-of-range values are
    clamped to ``[0, inp.size(dim)]``. An explicit ``end=0`` therefore
    selects an empty slice rather than the whole dimension.

    Args:
        inp: Tensor providing every element outside the slice.
        src: Tensor providing the slice contents. Its shape must equal the
            shape of ``inp`` with ``dim`` replaced by the slice length.
        dim: Dimension to slice along; may be negative.
        start: First index of the slice, or ``None`` for 0.
        end: One-past-last index of the slice, or ``None`` for the full size.
        step: Positive stride of the slice.

    Returns:
        A new tensor with the dtype and device of ``inp``. It reuses
        ``inp``'s strides unless ``inp`` has internally overlapping memory,
        in which case it is contiguous.

    Raises:
        AssertionError: If the devices differ, ``dim`` is out of range,
            ``step`` is not positive, or ``src`` has the wrong shape.
    """
    logger.debug("GEMS_CAMBRICON SLICE_SCATTER")
    assert src.device == inp.device, "inp and src reside on different devices."
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert step > 0, "slice step must be positive"
    dim = dim % inp.ndim
    dim_size = inp.size(dim)

    # Use explicit None checks: `start or 0` / `end or size` would turn a
    # legitimate end=0 (empty slice) into "slice to the end".
    if start is None:
        start = 0
    if end is None:
        end = dim_size
    # Negative bounds count from the end of the dimension.
    if start < 0:
        start += dim_size
    if end < 0:
        end += dim_size
    # Clamp so that sentinels such as sys.maxsize, or bounds that are still
    # negative after wrapping, describe a valid (possibly empty) slice.
    start = min(max(start, 0), dim_size)
    end = min(max(end, start), dim_size)

    valid_shape = list(inp.shape)
    valid_shape[dim] = triton.cdiv(end - start, step)
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    if has_internal_overlapping(inp) == MemOverlap.Yes:
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    total_elements = inp.numel()
    if total_elements == 0:
        # Nothing to compute; avoid launching a kernel with an empty grid.
        return out

    # The kernel addresses all buffers with flat, contiguous offsets. When
    # `out` mirrors a non-contiguous `inp`, writing into it directly would
    # scramble the elements, so compute into a contiguous staging buffer and
    # copy into the strided result afterwards.
    if out.is_contiguous():
        kernel_out = out
    else:
        kernel_out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)

    inp = inp.contiguous()
    src = src.contiguous()

    src_dim_size = src.size(dim)

    # Number of elements spanned by one step along `dim`.
    dim_prod_post = 1
    for d in range(dim + 1, inp.ndim):
        dim_prod_post *= inp.size(d)

    BLOCK_SIZE = 2048
    grid = (triton.cdiv(total_elements, BLOCK_SIZE),)

    slice_scatter_kernel[grid](
        kernel_out,
        inp,
        src,
        total_elements,
        dim_size,
        dim_prod_post,
        start,
        step,
        src_dim_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    if kernel_out is not out:
        out.copy_(kernel_out)

    return out