Coverage for src/flag_gems/experimental_ops/slice_scatter.py: 0%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Flat element-wise copy: y[i] = x[i] for every i < n_elements."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(y_ptr + idx, vals, mask=in_bounds)

14 

15 

@triton.jit
def _slice_scatter_kernel(
    src_ptr,  # pointer to src tensor (flattened)
    out_ptr,  # pointer to output tensor (flattened)
    outer,  # number of chunks before the sliced dimension
    dim_size,  # size of the sliced dimension in the output
    inner,  # number of elements after the sliced dimension
    start,  # start index along the sliced dimension
    step,  # step along the sliced dimension
    m_size,  # number of indices along the sliced dimension to scatter (len of slice)
    n_src_elements,  # total number of elements in src (outer * m_size * inner)
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter the flattened src into the strided slice region of out.

    Each flat src index is decomposed as (o, m, i) over (outer, m_size,
    inner); the element is written to out at flat index
    o * dim_size * inner + (start + m * step) * inner + i.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_src_elements

    # Promote to int64 for intermediate index math so outputs larger than
    # 2**31 elements don't overflow 32-bit arithmetic.
    offs_i64 = offs.to(tl.int64)

    inner64 = tl.full([BLOCK_SIZE], inner, tl.int64)
    m_size64 = tl.full([BLOCK_SIZE], m_size, tl.int64)
    dim_size64 = tl.full([BLOCK_SIZE], dim_size, tl.int64)
    start64 = tl.full([BLOCK_SIZE], start, tl.int64)
    step64 = tl.full([BLOCK_SIZE], step, tl.int64)

    # Decompose the flat src offset into (o, m, i) coordinates.
    chunk64 = m_size64 * inner64
    o = offs_i64 // chunk64
    rem = offs_i64 - o * chunk64
    m = rem // inner64
    i = rem - m * inner64

    dest_d = start64 + m * step64  # index along sliced dimension
    dest_linear = o * (dim_size64 * inner64) + dest_d * inner64 + i

    val = tl.load(src_ptr + offs, mask=mask)
    # BUG FIX: the store previously cast dest_linear back to int32, which
    # truncated destination offsets for outputs with >= 2**31 elements and
    # defeated the int64 promotion above. Keep the int64 offsets.
    tl.store(out_ptr + dest_linear, val, mask=mask)

53 

54 

55def _normalize_slice_params(size, start, end, step): 

56 assert step is not None and step != 0, "step must be non-zero" 

57 # This implementation supports only positive step for simplicity 

58 assert step > 0, "Only positive step is supported in this Triton implementation" 

59 if start is None: 

60 start = 0 

61 if end is None: 

62 end = size 

63 if start < 0: 

64 start += size 

65 if end < 0: 

66 end += size 

67 # Clamp to [0, size] 

68 start = max(0, min(start, size)) 

69 end = max(0, min(end, size)) 

70 if end <= start: 

71 m = 0 

72 else: 

73 m = (end - start + step - 1) // step 

74 return start, end, step, m 

75 

76 

def _slice_scatter_impl(input, src, out, dim=0, start=None, end=None, step=1):
    """Copy `input` into `out`, then overwrite the region
    out[..., start:end:step, ...] (along `dim`) with `src`.

    All three tensors must be contiguous CUDA tensors of the same dtype,
    and `out` must have the same shape as `input`. Returns `out`.
    """
    assert (
        input.is_cuda and src.is_cuda and out.is_cuda
    ), "All tensors must be CUDA tensors"
    assert (
        input.is_contiguous() and src.is_contiguous() and out.is_contiguous()
    ), "Tensors must be contiguous"
    assert input.dtype == src.dtype == out.dtype, "All tensors must have the same dtype"
    assert input.shape == out.shape, "Output must have same shape as input"

    block = 1024
    ndim = input.dim()

    if ndim == 0:
        # Scalar tensor: there is no dimension to slice, so the result is
        # just a copy of the input (nothing is scattered).
        total = input.numel()
        if total > 0:
            _copy_kernel[lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)](
                input, out, total, BLOCK_SIZE=block
            )
        return out

    if dim < 0:
        dim += ndim
    assert 0 <= dim < ndim, "dim out of range"

    size_d = input.size(dim)
    s, e, st, m = _normalize_slice_params(size_d, start, end, step)

    # Collapse the dimensions before/after `dim` into flat extents.
    outer = 1
    for extent in input.shape[:dim]:
        outer *= extent
    inner = 1
    for extent in input.shape[dim + 1:]:
        inner *= extent

    # Validate src shape/numel against the slice geometry.
    expected_src_numel = outer * m * inner
    assert src.numel() == expected_src_numel, (
        f"src numel mismatch: got {src.numel()}, expected {expected_src_numel} "
        f"(outer={outer}, m={m}, inner={inner})"
    )

    # Stage 1: replicate input into out.
    total = input.numel()
    if total > 0:
        _copy_kernel[lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)](
            input, out, total, BLOCK_SIZE=block
        )

    # Stage 2: overwrite the sliced region with src.
    if expected_src_numel > 0:
        _slice_scatter_kernel[
            lambda meta: (triton.cdiv(expected_src_numel, meta["BLOCK_SIZE"]),)
        ](
            src,
            out,
            outer,
            size_d,
            inner,
            s,
            st,
            m,
            expected_src_numel,
            BLOCK_SIZE=block,
        )

    return out

142 

143 

def slice_scatter(input, src, dim=0, start=None, end=None, step=1):
    """Out-of-place slice_scatter: return a fresh tensor equal to `input`
    with the slice [start:end:step] along `dim` replaced by `src`."""
    result = torch.empty_like(input)
    return _slice_scatter_impl(
        input, src, result, dim=dim, start=start, end=end, step=step
    )

149 

150 

def slice_scatter_out(input, src, dim=0, start=None, end=None, step=1, out=None):
    """slice_scatter variant writing into a caller-provided `out` tensor;
    a fresh tensor is allocated when `out` is None."""
    destination = torch.empty_like(input) if out is None else out
    return _slice_scatter_impl(
        input, src, destination, dim=dim, start=start, end=end, step=step
    )