Coverage for src/flag_gems/ops/slice_backward.py: 68%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

6@triton.jit 

7def slice_backward_kernel( 

8 grad_output_ptr, 

9 grad_input_ptr, 

10 numel, 

11 inner, 

12 slice_len, 

13 dim_size, 

14 start, 

15 step, 

16 BLOCK_SIZE: tl.constexpr, 

17): 

18 pid = tl.program_id(0) 

19 

20 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

21 

22 mask = offsets < numel 

23 

24 grad = tl.load(grad_output_ptr + offsets, mask=mask) 

25 

26 outer_idx = offsets // (slice_len * inner) 

27 

28 slice_idx = (offsets // inner) % slice_len 

29 

30 inner_idx = offsets % inner 

31 

32 dim_index = start + slice_idx * step 

33 

34 input_offset = outer_idx * dim_size * inner + dim_index * inner + inner_idx 

35 

36 tl.store(grad_input_ptr + input_offset, grad, mask=mask) 

37 

38 

39def slice_backward( 

40 grad_output, 

41 input_sizes, 

42 dim, 

43 start, 

44 end, 

45 step, 

46): 

47 grad_input = torch.zeros( 

48 input_sizes, 

49 device=grad_output.device, 

50 dtype=grad_output.dtype, 

51 ) 

52 

53 shape = list(input_sizes) 

54 

55 slice_len = (end - start + step - 1) // step 

56 

57 outer = 1 

58 for i in range(dim): 

59 outer *= shape[i] 

60 

61 inner = 1 

62 for i in range(dim + 1, len(shape)): 

63 inner *= shape[i] 

64 

65 dim_size = shape[dim] 

66 

67 numel = grad_output.numel() 

68 

69 BLOCK = 1024 

70 

71 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),) 

72 

73 slice_backward_kernel[grid]( 

74 grad_output, 

75 grad_input, 

76 numel, 

77 inner, 

78 slice_len, 

79 dim_size, 

80 start, 

81 step, 

82 BLOCK_SIZE=BLOCK, 

83 ) 

84 

85 return grad_input