Coverage for src/flag_gems/ops/select_backward.py: 59%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7 

@triton.jit
def _select_backward_kernel(
    grad_ptr,
    out_ptr,
    outer_size,
    inner_size,
    dim_stride,
    index,
    BLOCK: tl.constexpr,
):
    # Scatter the flattened gradient (outer_size x inner_size) into the full
    # output tensor at position `index` along the selected dimension.
    # The caller pre-zeroes `out`; each program writes BLOCK elements of the
    # selected slice and nothing else.
    block_id = tl.program_id(0)
    flat = block_id * BLOCK + tl.arange(0, BLOCK)
    in_bounds = flat < outer_size * inner_size

    # Decompose the flat grad offset into (outer, inner) coordinates.
    outer_idx = flat // inner_size
    inner_idx = flat - outer_idx * inner_size

    # grad is contiguous (outer, inner), so its flat offset is just `flat`
    # (outer_idx * inner_size + inner_idx == flat by construction).
    vals = tl.load(grad_ptr + flat, mask=in_bounds)

    # Flat offset of element (outer_idx, index, inner_idx) in the full output,
    # where dim_stride == dim_size * inner_size.
    dst = outer_idx * dim_stride + index * inner_size + inner_idx

    tl.store(out_ptr + dst, vals, mask=in_bounds)

33 

34 

35def _launch_select_backward(grad, input_sizes, dim, index, out=None): 

36 if not grad.is_cuda: 

37 raise ValueError("grad must be CUDA tensor") 

38 

39 dim = int(dim) 

40 index = int(index) 

41 

42 sizes = list(input_sizes) 

43 ndim = len(sizes) 

44 

45 if dim < 0: 

46 dim += ndim 

47 

48 if dim < 0 or dim >= ndim: 

49 raise ValueError("invalid dim") 

50 

51 dim_size = sizes[dim] 

52 

53 if index < 0 or index >= dim_size: 

54 raise ValueError("index out of range") 

55 

56 outer_size = math.prod(sizes[:dim]) if dim > 0 else 1 

57 inner_size = math.prod(sizes[dim + 1 :]) if dim < ndim - 1 else 1 

58 

59 grad_view = grad.contiguous().view(outer_size, inner_size) 

60 

61 if out is None: 

62 out = torch.zeros( 

63 sizes, 

64 dtype=grad.dtype, 

65 device=grad.device, 

66 ) 

67 else: 

68 if tuple(out.shape) != tuple(sizes): 

69 raise ValueError("out shape mismatch") 

70 if out.dtype != grad.dtype: 

71 raise ValueError("dtype mismatch") 

72 if out.device != grad.device: 

73 raise ValueError("device mismatch") 

74 

75 out.zero_() 

76 

77 dim_stride = dim_size * inner_size 

78 

79 BLOCK = 1024 

80 n_elements = outer_size * inner_size 

81 grid = (triton.cdiv(n_elements, BLOCK),) 

82 

83 _select_backward_kernel[grid]( 

84 grad_view, 

85 out, 

86 outer_size, 

87 inner_size, 

88 dim_stride, 

89 index, 

90 BLOCK=BLOCK, 

91 ) 

92 

93 return out 

94 

95 

def select_backward(grad, input_sizes, dim, index, out=None):
    """Public entry point for the backward of ``tensor.select(dim, index)``.

    Thin wrapper that forwards every argument to
    :func:`_launch_select_backward`; see that function for details.
    """
    return _launch_select_backward(
        grad,
        input_sizes,
        dim,
        index,
        out=out,
    )