Coverage for src/flag_gems/experimental_ops/masked_select.py: 0% — 85 statements (coverage.py v7.6.9, created at 2026-03-16 02:02 +0800)

import torch

import triton
import triton.language as tl

4 

5 

@triton.jit
def _masked_select_count_kernel(
    mask_ptr,  # int32* flattened mask (0/1)
    n_elements,  # int32 number of elements
    counts_ptr,  # int32* per-block counts
    BLOCK_SIZE: tl.constexpr,
):
    """Store, per program, how many mask entries in its block are nonzero."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements

    # Out-of-range lanes load 0 and therefore do not affect the sum.
    bits = tl.load(mask_ptr + idx, mask=valid, other=0)
    tl.store(counts_ptr + block_id, tl.sum(bits, axis=0))

21 

22 

@triton.jit
def _masked_select_scatter_kernel(
    input_ptr,  # * input data (flattened, contiguous)
    mask_ptr,  # int32* flattened mask (0/1)
    block_offsets_ptr,  # int32* per-block exclusive offsets
    output_ptr,  # * output data
    n_elements,  # int32 number of elements
    BLOCK_SIZE: tl.constexpr,
):
    """Compact the selected input elements into output, one block per program."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements

    bits = tl.load(mask_ptr + idx, mask=valid, other=0)
    # tl.cumsum is inclusive; subtracting 1 gives each selected lane its
    # exclusive position inside the block (meaningless where bits == 0,
    # but those lanes are masked out of the store below).
    local = tl.cumsum(bits, axis=0) - 1

    base = tl.load(block_offsets_ptr + block_id)  # int32 global start of this block

    selected = valid & (bits != 0)
    data = tl.load(input_ptr + idx, mask=selected, other=0)
    tl.store(output_ptr + base + local, data, mask=selected)

48 

49 

50def _prepare_broadcast_flatten(input: torch.Tensor, mask: torch.Tensor): 

51 # Broadcast input and mask to a common shape 

52 bshape = torch.broadcast_shapes(tuple(input.shape), tuple(mask.shape)) 

53 inp_b = input.expand(bshape) 

54 msk_b = mask.to(torch.bool).expand(bshape) 

55 

56 # Make contiguous flattened views 

57 inp_flat = inp_b.contiguous().view(-1) 

58 msk_flat_bool = msk_b.contiguous().view(-1) 

59 # Convert mask to int32 (0/1) for kernels 

60 msk_flat_i32 = msk_flat_bool.to(torch.int32) 

61 return inp_flat, msk_flat_i32 

62 

63 

def masked_select(input: torch.Tensor, mask: torch.Tensor):
    """Return a new 1-D tensor with the elements of ``input`` selected by
    ``mask`` (the two are broadcast against each other first).

    Strategy: one kernel counts selected elements per block, the host
    turns those counts into exclusive block offsets, and a second kernel
    scatters the selected values to their compacted positions.
    """
    flat_input, flat_mask = _prepare_broadcast_flatten(input, mask)
    device = flat_input.device
    assert flat_mask.device == device, "input and mask must be on the same device"

    n_elements = flat_input.numel()
    if n_elements == 0:
        return torch.empty(0, dtype=input.dtype, device=device)

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    counts = torch.empty(grid[0], dtype=torch.int32, device=device)
    _masked_select_count_kernel[grid](
        flat_mask, n_elements, counts, BLOCK_SIZE=BLOCK_SIZE
    )

    # Exclusive prefix sum of the per-block counts; the final inclusive
    # entry is the total number of selected elements. This also covers
    # the single-block case (prefix[:-1] is then empty).
    prefix = torch.cumsum(counts, dim=0)
    block_offsets = torch.empty_like(counts)
    block_offsets[0] = 0
    block_offsets[1:] = prefix[:-1]
    total_selected = int(prefix[-1].item())

    output = torch.empty(total_selected, dtype=input.dtype, device=device)
    _masked_select_scatter_kernel[grid](
        flat_input, flat_mask, block_offsets, output, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )
    return output

98 

99 

def masked_select_out(input: torch.Tensor, mask: torch.Tensor, out: torch.Tensor):
    """Out-variant of ``masked_select``: resize ``out`` to the number of
    selected elements and fill it with them, returning ``out``.

    Raises:
        RuntimeError: if ``out`` is on a different device than ``input``,
            or if ``out.dtype`` differs from ``input.dtype``.
    """
    inp_flat, msk_flat_i32 = _prepare_broadcast_flatten(input, mask)
    device = inp_flat.device
    assert msk_flat_i32.device == device, "input and mask must be on the same device"

    # Validate `out` before doing any work. The previous version checked
    # the dtype only after the count kernel had already run, and skipped
    # the check entirely on the empty-input early return, silently
    # accepting a mismatched `out` there.
    if out.device != device:
        raise RuntimeError("out tensor must be on the same device as input")
    if out.dtype != input.dtype:
        raise RuntimeError("out tensor dtype must match input dtype")

    n_elements = inp_flat.numel()
    if n_elements == 0:
        out.resize_(0)
        return out

    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(n_elements, BLOCK_SIZE)
    grid = (num_blocks,)

    counts = torch.empty(num_blocks, dtype=torch.int32, device=device)
    _masked_select_count_kernel[grid](
        msk_flat_i32, n_elements, counts, BLOCK_SIZE=BLOCK_SIZE
    )

    # Exclusive prefix sum of per-block counts; the final inclusive entry
    # is the total selected count. Handles num_blocks == 1 as well, since
    # prefix[:-1] is then empty.
    prefix = torch.cumsum(counts, dim=0)
    block_offsets = torch.empty_like(counts)
    block_offsets[0] = 0
    block_offsets[1:] = prefix[:-1]
    total_selected = int(prefix[-1].item())

    out.resize_(total_selected)
    _masked_select_scatter_kernel[grid](
        inp_flat, msk_flat_i32, block_offsets, out, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )
    return out