Coverage for src/flag_gems/experimental_ops/masked_scatter.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _masked_scatter_count_kernel(
    mask_ptr,  # *Pointer* to mask tensor (bool)
    counts_ptr,  # *Pointer* to per-block counts (int32)
    n_elements,  # Number of elements in the flattened input
    BLOCK_SIZE: tl.constexpr,
):
    """Write the number of True mask elements in each program's block to counts_ptr[pid]."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements

    # Out-of-range lanes read 0 via `other=0`, so they add nothing to the sum.
    flags = tl.load(mask_ptr + idx, mask=valid, other=0).to(tl.int32)
    tl.store(counts_ptr + block_id, tl.sum(flags, axis=0))

22 

23 

@triton.jit
def _masked_scatter_apply_kernel(
    in_ptr,  # *Pointer* to input tensor
    mask_ptr,  # *Pointer* to mask tensor (bool)
    src_ptr,  # *Pointer* to source tensor (1D)
    out_ptr,  # *Pointer* to output tensor
    n_elements,  # Number of elements in the flattened input
    prefix_ptr,  # *Pointer* to per-block exclusive prefix sums (int32)
    BLOCK_SIZE: tl.constexpr,
):
    """Write src values at True mask positions and input values elsewhere.

    The k-th True position (in flattened order) receives src_ptr[k]; the
    global index k is this block's exclusive prefix offset plus the element's
    exclusive rank within the block.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements

    vals = tl.load(in_ptr + idx, mask=valid)
    flags = tl.load(mask_ptr + idx, mask=valid, other=0).to(tl.int32)

    # Exclusive rank of each True element within this block:
    # inclusive cumsum minus the element's own flag.
    local_rank = tl.cumsum(flags, axis=0) - flags

    # Offset of this block's first True element within the source tensor.
    base = tl.load(prefix_ptr + block_id, mask=True, other=0).to(local_rank.dtype)
    src_index = base + local_rank

    selected = flags != 0
    # Only selected in-bounds lanes touch src_ptr; others yield the dummy 0.
    picked = tl.load(src_ptr + src_index, mask=(valid & selected), other=0)

    tl.store(out_ptr + idx, tl.where(selected, picked, vals), mask=valid)

55 

56 

def _launch_masked_scatter(
    input_tensor: torch.Tensor,
    mask: torch.Tensor,
    source: torch.Tensor,
    out_tensor: torch.Tensor = None,
):
    """Scatter ``source`` values into ``input_tensor`` where ``mask`` is True.

    Same-numel-mask variant of ``torch.masked_scatter``: the k-th True
    position of the flattened mask receives the k-th element of the
    flattened ``source``; all other positions keep the input value.

    Args:
        input_tensor: CUDA tensor providing values at False mask positions.
        mask: Tensor with the same number of elements as ``input_tensor``;
            cast to bool if needed.
        source: Tensor supplying at least as many elements as there are
            True entries in ``mask``.
        out_tensor: Optional pre-allocated output; must match the input's
            shape, dtype, and device. A new tensor is allocated when None.

    Returns:
        The result tensor (``out_tensor`` if provided, else a new tensor).

    Raises:
        ValueError: if any tensor is None, input/mask numel mismatch,
            ``out_tensor`` shape/dtype/device mismatch, non-CUDA input,
            or ``source`` has fewer elements than the mask requires.
    """
    # Validate inputs
    if input_tensor is None or mask is None or source is None:
        raise ValueError("masked_scatter requires input, mask, and source tensors")

    if mask.dtype != torch.bool:
        mask = mask.to(torch.bool)

    if input_tensor.numel() != mask.numel():
        raise ValueError("input and mask must have the same number of elements")

    if out_tensor is None:
        out = torch.empty_like(input_tensor)
    else:
        out = out_tensor
        if out.shape != input_tensor.shape:
            raise ValueError("out tensor must have the same shape as input")
        if out.dtype != input_tensor.dtype:
            raise ValueError("out tensor must have the same dtype as input")
        if out.device != input_tensor.device:
            raise ValueError("out tensor must be on the same device as input")

    device = input_tensor.device
    if device.type != "cuda":
        raise ValueError("Triton kernels require CUDA tensors")

    # Flatten to 1D contiguous views. contiguous() may copy; input/mask/source
    # are read-only so that is harmless, and a copied out_flat is written back
    # to `out` after the kernels run (see below).
    x_flat = input_tensor.contiguous().view(-1)
    m_flat = mask.contiguous().view(-1)
    s_flat = source.contiguous().view(-1)
    out_flat = out.contiguous().view(-1)

    n_elements = x_flat.numel()
    if n_elements == 0:
        # Nothing to scatter; the result is just the (empty) input.
        out.copy_(input_tensor)
        return out

    BLOCK_SIZE = 1024
    n_blocks = triton.cdiv(n_elements, BLOCK_SIZE)

    # 1) Count number of True mask elements per block
    counts = torch.empty(n_blocks, dtype=torch.int32, device=device)
    grid = (n_blocks,)
    _masked_scatter_count_kernel[grid](
        m_flat, counts, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )

    # 2) Compute exclusive prefix sums of per-block counts and validate that
    #    the source provides enough elements for every True mask position.
    counts_prefix = torch.cumsum(counts, dim=0)
    total_true = int(counts_prefix[-1].item()) if n_blocks > 0 else 0
    if s_flat.numel() < total_true:
        raise ValueError(
            f"source has fewer elements ({s_flat.numel()}) than required by mask ({total_true})"
        )
    prefix_exclusive = counts_prefix - counts  # int32, same device

    # 3) Apply masked_scatter using per-block prefix offsets
    _masked_scatter_apply_kernel[grid](
        x_flat,
        m_flat,
        s_flat,
        out_flat,
        n_elements,
        prefix_exclusive,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # If `out` was non-contiguous, contiguous() above produced a copy and the
    # kernel wrote into that copy — write the result back into `out`.
    # BUG FIX: the original used `out.view(-1).copy_(...)`, but .view() raises
    # a RuntimeError on non-contiguous tensors — the only case this branch is
    # reached. copy_ from a shaped view of the contiguous buffer works for any
    # output layout.
    if out.data_ptr() != out_flat.data_ptr():
        out.copy_(out_flat.view(out.shape))
    return out

134 

135 

def masked_scatter(input: torch.Tensor, mask: torch.Tensor, source: torch.Tensor):
    """Return a new tensor with ``source`` scattered into ``input`` at True mask positions."""
    return _launch_masked_scatter(input, mask, source, out_tensor=None)

138 

139 

def masked_scatter_out(
    input: torch.Tensor, mask: torch.Tensor, source: torch.Tensor, out: torch.Tensor
):
    """Scatter ``source`` into ``input`` at True mask positions, writing into ``out``."""
    return _launch_masked_scatter(input, mask, source, out_tensor=out)