Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/masked_fill.py: 0%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import broadcastable_to, libentry 

8from flag_gems.utils import triton_lang_extension as tle 

9 

# Module-scoped logger nested under the package-wide "flag_gems" logger.
# NOTE(review): lstrip(".") only strips *leading* dots; module names normally
# have none — presumably defensive, confirm original intent.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

11 

12 

def masked_fill_kernel_heur_block_size(args):
    """Heuristic BLOCK_SIZE for the out-of-place kernel.

    Splits the N elements evenly over the 12 XPU clusters and rounds the
    per-cluster share up to the next power of two.
    """
    per_cluster = triton.cdiv(args["N"], 12)  # 12 == cluster_num
    return triton.next_power_of_2(per_cluster)

15 

16 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
# @triton.heuristics(
#     values={
#         "BLOCK_SIZE": masked_fill_kernel_heur_block_size,
#     },
# )
@triton.jit
def masked_fill_kernel(
    inp, expand_mask, value, out, N: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    # Out-of-place masked fill over flat, contiguous buffers:
    #   out[i] = value          where expand_mask[i] != 0
    #   out[i] = inp[i]         otherwise, for all i < N.
    pid = tle.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N  # bounds guard for the final partial block

    # expand_mask is passed as an int tensor by the wrapper; nonzero marks a
    # position to be filled with `value`.
    fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
    # Copy-through lanes: in range AND not selected for filling.
    # NOTE(review): `not`/`and` on tensors relies on this backend's Triton
    # AST lowering treating them element-wise — confirm against the vendor
    # toolchain before restyling to `~`/`&`.
    cur_inp = tl.load(inp + offsets, mask=(not fill_mask) and mask, other=0)
    # Offsets of suppressed lanes are forced to -1; the matching store mask
    # keeps them from ever being dereferenced.
    out_offset_1 = tl.where((not fill_mask) and mask, offsets, -1)
    tl.store(out + out_offset_1, cur_inp, (not fill_mask) and mask)
    out_offset_2 = tl.where(fill_mask and mask, offsets, -1)
    tl.store(out + out_offset_2, value, fill_mask and mask)

38 

39 

def masked_fill_kernel_self_heur_block_size(args):
    """Heuristic BLOCK_SIZE for the in-place kernel.

    Same policy as the out-of-place variant: divide N across the 12 XPU
    clusters, then round up to a power of two.
    """
    share = triton.cdiv(args["N"], 12)  # 12 == cluster_num
    return triton.next_power_of_2(share)

42 

43 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
# @triton.heuristics(
#     values={
#         "BLOCK_SIZE": masked_fill_kernel_self_heur_block_size,
#     },
# )
@triton.jit
def masked_fill_kernel_self(inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr):
    # In-place masked fill over a flat, contiguous buffer:
    #   inp[i] = value  where expand_mask[i] != 0, for all i < N.
    # Unselected positions are left untouched (no load/copy needed).
    pid = tle.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N  # bounds guard for the final partial block

    # expand_mask is an int tensor (see wrapper); nonzero selects a lane.
    fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
    # NOTE(review): `and` on tensors relies on this backend's Triton AST
    # lowering being element-wise — confirm before restyling to `&`.
    tl.store(inp + offsets, value, fill_mask and mask)

59 

60 

def masked_fill(inp, mask, value):
    """Out-of-place masked fill.

    Returns a new tensor equal to ``inp`` except that every position selected
    by ``mask`` (broadcast to ``inp.shape``) holds ``value``.

    Args:
        inp: input tensor (any shape, including 0-d).
        mask: boolean tensor broadcastable to ``inp.shape``.
        value: 0-d tensor or Python int/float used as the fill value.

    Raises:
        AssertionError: if ``value`` is a tensor with ndim > 0, or ``mask``
            is not broadcastable to ``inp.shape``.
    """
    logger.debug("GEMS MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar; reduce to a Python scalar.
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # 0-d input: resolve on the host, no kernel launch needed.
        return (
            torch.tensor(value, dtype=inp.dtype, device=inp.device)
            if mask.item()
            else inp.clone()
        )

    inp = inp.contiguous()
    mask = mask.contiguous()
    expand_mask = mask.expand(inp.shape)
    # empty_like already preserves dtype/device; no need to repeat them.
    out = torch.empty_like(inp)

    N = inp.numel()
    if N == 0:
        # Empty tensor: nothing to fill; the empty result is already correct.
        return out
    grid = 12  # one program per XPU cluster
    BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(N, grid))

    import os

    # XPU simulator knobs required by this backend's masked-store lowering.
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
    try:
        masked_fill_kernel[grid,](
            inp,
            expand_mask.to(torch.int),
            value,
            out,
            N,
            BLOCK_SIZE,
            isCloseUnrollControl=True,
            buffer_size_limit=2048,
        )
    finally:
        # Restore the environment even if the launch raises, so the flags
        # never leak into unrelated kernel launches.
        os.environ.pop("TRITONXPU_OTHER_SIM", None)
        os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
    return out

114 

115 

def masked_fill_(inp, mask, value):
    """In-place masked fill.

    Writes ``value`` into every position of ``inp`` selected by ``mask``
    (broadcast to ``inp.shape``) and returns ``inp``.

    Args:
        inp: tensor to modify in place.
        mask: boolean tensor broadcastable to ``inp.shape``.
        value: 0-d tensor or Python int/float used as the fill value.

    Raises:
        AssertionError: if ``value`` is a tensor with ndim > 0, or ``mask``
            is not broadcastable to ``inp.shape``.
    """
    logger.debug("GEMS MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar; reduce to a Python scalar.
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # 0-d input: assign on the host, no kernel launch needed.
        if mask.item():
            inp[()] = value
        return inp

    # NOTE(review): for a non-contiguous input, .contiguous() makes a copy —
    # the kernel then fills the copy, not the caller's storage, so strict
    # in-place semantics are lost (the filled copy is still returned).
    # Confirm callers only pass contiguous tensors.
    inp = inp.contiguous()
    mask = mask.contiguous()
    expand_mask = mask.expand(inp.shape)

    N = inp.numel()
    if N == 0:
        # Empty tensor: nothing to fill.
        return inp

    import os

    # XPU simulator knobs required by this backend's masked-store lowering.
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"

    grid = 12  # one program per XPU cluster
    BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(N, grid))
    try:
        masked_fill_kernel_self[grid,](
            inp, expand_mask.to(torch.int), value, N, BLOCK_SIZE, buffer_size_limit=2048
        )
    finally:
        # Restore the environment even if the launch raises, so the flags
        # never leak into unrelated kernel launches.
        os.environ.pop("TRITONXPU_OTHER_SIM", None)
        os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
    return inp