Coverage for src/flag_gems/experimental_ops/slice_backward.py: 0% (76 statements)

import torch
import triton
import triton.language as tl


@triton.jit
def _slice_backward_scatter_kernel(
    grad_ptr,  # *Pointer* to grad (input) vector
    out_ptr,  # *Pointer* to output (full grad) vector
    n_elements,  # numel of grad
    inner,  # product of sizes after 'dim'
    gdim,  # size of grad along 'dim'
    odim,  # size of output along 'dim'
    start,  # normalized start index along 'dim'
    step,  # step along 'dim'
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    offs = offs.to(tl.int64)
    mask = offs < n_elements

    inner_i64 = tl.full([1], inner, tl.int64)
    gdim_i64 = tl.full([1], gdim, tl.int64)
    odim_i64 = tl.full([1], odim, tl.int64)
    start_i64 = tl.full([1], start, tl.int64)
    step_i64 = tl.full([1], step, tl.int64)

    # Decompose linear index into (outer, g_idx_dim, inner_idx)
    outer = offs // (gdim_i64 * inner_i64)
    inner_idx = offs % inner_i64
    g_idx_dim = (offs // inner_i64) % gdim_i64

    out_dim_index = start_i64 + g_idx_dim * step_i64
    valid_o = (out_dim_index >= 0) & (out_dim_index < odim_i64)
    o = outer * (odim_i64 * inner_i64) + out_dim_index * inner_i64 + inner_idx

    m = mask & valid_o
    val = tl.load(grad_ptr + offs, mask=m, other=0)
    tl.store(out_ptr + o, val, mask=m)
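

# Worked example of the decomposition above (illustrative values, not part of
# the original module): for input_sizes = (4, 10, 3), dim = 1, start = 1,
# step = 2, grad has shape (4, 4, 3), so inner = 3, gdim = 4, odim = 10.
# The grad element at linear offset 23 decomposes as
#     outer = 23 // (4 * 3) = 1
#     g_idx_dim = (23 // 3) % 4 = 3
#     inner_idx = 23 % 3 = 2
# so out_dim_index = 1 + 3 * 2 = 7 and the destination offset is
#     1 * (10 * 3) + 7 * 3 + 2 = 53.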


def _normalize_slice_params(input_sizes, dim, start, end, step):
    # `end` is accepted for API parity with aten::slice_backward but is not
    # needed here: the slice length is already fixed by grad's shape.
    D = len(input_sizes)
    if dim < 0:
        dim += D
    size_dim = int(input_sizes[dim])

    if step is None:
        step = 1
    if step == 0:
        raise ValueError("slice step cannot be zero")

    if start is None:
        start = 0 if step > 0 else size_dim - 1

    # Normalize start into valid index range
    if start < 0:
        start += size_dim

    if step > 0:
        # Clamp into [0, size_dim]
        if start < 0:
            start = 0
        if start > size_dim:
            start = size_dim
    else:
        # Clamp into [0, size_dim - 1]
        if start < 0:
            start = 0
        if start >= size_dim:
            start = size_dim - 1

    return dim, int(start), int(step)
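

# Example (illustrative values, not part of the original module):
#
#     _normalize_slice_params([4, 10, 3], -2, -9, None, 2)  ->  (1, 1, 2)
#
# dim -2 resolves to axis 1, start -9 wraps around to 1, and step passes
# through unchanged.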


def _launch_slice_backward_kernel(grad, input_sizes, dim, start, end, step, out):
    # Ensure contiguous
    grad_c = grad.contiguous()
    out_c = out.contiguous()

    # Normalize parameters
    dim, start_n, step_n = _normalize_slice_params(
        list(input_sizes), int(dim), start, end, step
    )

    # Compute inner, gdim, odim
    sizes = list(input_sizes)
    odim = int(sizes[dim])
    gdim = int(grad_c.shape[dim])
    inner = 1
    for s in sizes[dim + 1 :]:
        inner *= int(s)

    # Positions outside the slice receive zero gradient, so the output must
    # be cleared even when grad is empty.
    out_c.zero_()

    n_elements = grad_c.numel()
    if n_elements > 0:
        # Launch Triton kernel
        BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _slice_backward_scatter_kernel[grid](
            grad_c,
            out_c,
            n_elements,
            inner,
            gdim,
            odim,
            start_n,
            step_n,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    # contiguous() may have returned a copy; propagate results back to `out`.
    if out_c is not out:
        out.copy_(out_c)
    return out
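

# Note on the scatter formulation above (added commentary): each Triton
# program reads a dense, contiguous block of grad offsets and computes the
# strided destination address in the zeroed output. Distinct grad elements
# always map to distinct output elements (the index map is injective for any
# nonzero step), so writes never collide and no atomics are required.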


def slice_backward(grad, input_sizes, dim, start, end, step):
    """
    Python wrapper for aten::slice_backward
    """
    out = torch.empty(tuple(input_sizes), device=grad.device, dtype=grad.dtype)
    out = _launch_slice_backward_kernel(grad, input_sizes, dim, start, end, step, out)
    return out
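

# Example usage (a sketch, assuming a CUDA device with Triton available):
#
#     x = torch.randn(4, 10, 3, device="cuda")
#     y = x[:, 1:9:2, :]  # forward slice
#     gx = slice_backward(torch.ones_like(y), x.shape, 1, 1, 9, 2)
#     # gx is zero everywhere except ones at indices 1, 3, 5, 7 along dim 1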


def slice_backward_out(grad, input_sizes, dim, start, end, step, out):
    """
    Python wrapper for aten::slice_backward.out
    """
    if tuple(out.shape) != tuple(input_sizes):
        raise ValueError("Output tensor shape must match input_sizes")
    if out.device != grad.device or out.dtype != grad.dtype:
        raise ValueError("Output tensor must have same device and dtype as grad")
    _launch_slice_backward_kernel(grad, input_sizes, dim, start, end, step, out)
    return out
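

# Minimal smoke test (a sketch, not part of the original module): checks the
# Triton path against PyTorch autograd on a single strided slice. Assumes a
# CUDA device with Triton available.
if __name__ == "__main__":
    x = torch.randn(4, 10, 3, device="cuda", requires_grad=True)
    y = x[:, 1:9:2, :]
    grad_out = torch.randn_like(y)
    (expected,) = torch.autograd.grad(y, x, grad_out)
    actual = slice_backward(grad_out, x.shape, 1, 1, 9, 2)
    torch.testing.assert_close(actual, expected)
    print("slice_backward matches autograd")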