Coverage for src/flag_gems/runtime/backend/_ascend/ops/masked_fill.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import broadcastable_to, libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Module-level logger, named "flag_gems.runtime._ascend.ops.masked_fill".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

12 

13 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
@triton.jit
def masked_fill_kernel(
    inp,
    expand_mask,
    value,
    out,
    N,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    """Out-of-place masked fill over flat, contiguous 1-D buffers.

    Each program owns BLOCK_SIZE contiguous elements and sweeps them in
    sub-tiles of BLOCK_SIZE_SUB. Where ``expand_mask`` is nonzero,
    ``value`` is written to ``out``; elsewhere the input element is
    copied through unchanged. Elements at ``offsets >= N`` are skipped.
    """
    pid = tle.program_id(axis=0)
    base_offset = pid * BLOCK_SIZE

    # Number of sub-tiles this program iterates over.
    num_sub_blocks = BLOCK_SIZE // BLOCK_SIZE_SUB

    for sub_block_idx in range(num_sub_blocks):
        # Offsets of the current sub-tile.
        sub_offset = base_offset + sub_block_idx * BLOCK_SIZE_SUB
        offsets = sub_offset + tl.arange(0, BLOCK_SIZE_SUB)
        mask = offsets < N

        # Load input and fill mask.
        input_vals = tl.load(inp + offsets, mask=mask, other=0)
        fill_mask_vals = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)

        # Select value/input per element and store once. The original
        # stored `input_vals` to `out`, reloaded the same addresses, and
        # stored a second time; the reloaded values are exactly
        # `input_vals` (same program, same addresses), so one select and
        # one store is equivalent and drops the extra global-memory
        # round-trip.
        value_to_write = tl.full([BLOCK_SIZE_SUB], value, dtype=input_vals.dtype)
        result = tl.where(fill_mask_vals, value_to_write, input_vals)
        tl.store(out + offsets, result, mask=mask)

52 

53 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
@triton.jit
def masked_fill_kernel_self(
    inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr
):
    """In-place masked fill over a flat, contiguous 1-D buffer.

    Each program owns BLOCK_SIZE contiguous elements, swept in tiles of
    BLOCK_SIZE_SUB. Where ``expand_mask`` is nonzero, ``value``
    overwrites the corresponding element of ``inp``; other elements are
    rewritten with their original value. Elements past ``N`` are skipped.
    """
    block_start = tle.program_id(axis=0) * BLOCK_SIZE
    lane = tl.arange(0, BLOCK_SIZE_SUB)

    for tile in range(BLOCK_SIZE // BLOCK_SIZE_SUB):
        offsets = block_start + tile * BLOCK_SIZE_SUB + lane
        in_bounds = offsets < N

        # Nonzero mask entries select `value`; others keep the original.
        do_fill = tl.load(expand_mask + offsets, mask=in_bounds, other=0).to(tl.int1)
        current = tl.load(inp + offsets, mask=in_bounds, other=0)
        filled = tl.full([BLOCK_SIZE_SUB], value, dtype=current.dtype)
        tl.store(inp + offsets, tl.where(do_fill, filled, current), mask=in_bounds)

83 

84 

def masked_fill(inp, mask, value):
    """Out-of-place masked fill: return a copy of ``inp`` with ``value``
    written wherever ``mask`` is true.

    Args:
        inp: source tensor (any shape, including 0-dim).
        mask: boolean tensor broadcastable to ``inp.shape``.
        value: fill value; a Python int/float or a 0-dim tensor.

    Returns:
        A new tensor with ``inp``'s shape and dtype, with masked
        positions set to ``value``.
    """
    logger.debug("GEMS_ASCEND MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # 0-dim input: resolve on the host, no kernel launch needed.
        return (
            torch.tensor(value, dtype=inp.dtype, device=inp.device)
            if mask.item()
            else inp.clone()
        )

    inp = inp.contiguous()
    # Materialize the broadcast mask as a contiguous int tensor. The
    # explicit .contiguous() matters: if `mask` already had an integer
    # dtype, .to(torch.int) would be a no-op and leave a stride-0
    # expanded view, which the kernel (it indexes linearly) would read
    # incorrectly.
    expand_mask = mask.expand(inp.shape).to(torch.int).contiguous()
    out = torch.empty_like(inp)

    N = inp.numel()
    if N == 0:
        # Nothing to fill; return the (empty) output directly.
        return out
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    masked_fill_kernel[grid](inp, expand_mask, value, out, N)
    return out

118 

119 

def masked_fill_(inp, mask, value):
    """In-place masked fill: write ``value`` into ``inp`` wherever
    ``mask`` is true and return ``inp``.

    Args:
        inp: tensor to mutate (any shape, including 0-dim).
        mask: boolean tensor broadcastable to ``inp.shape``.
        value: fill value; a Python int/float or a 0-dim tensor.

    Returns:
        ``inp`` itself, mutated in place.
    """
    logger.debug("GEMS_ASCEND MASKED FILL_")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # 0-dim input: assign on the host, no kernel launch needed.
        if mask.item():
            inp[()] = value
        return inp

    N = inp.numel()
    if N == 0:
        return inp

    # The kernel requires a contiguous buffer. For a non-contiguous
    # `inp`, .contiguous() returns a *copy*, so launching the kernel on
    # it would leave the caller's tensor untouched (the original code
    # rebound `inp` to that copy and mutated only the copy). Run on a
    # contiguous working buffer and copy the result back when a copy was
    # actually made, preserving true in-place semantics.
    work = inp.contiguous()
    # Explicit .contiguous(): if `mask` already had an integer dtype,
    # .to(torch.int) would be a no-op and leave a stride-0 expanded view,
    # which the kernel's linear indexing would read incorrectly.
    expand_mask = mask.expand(inp.shape).to(torch.int).contiguous()

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    masked_fill_kernel_self[grid](work, expand_mask, value, N)
    if work.data_ptr() != inp.data_ptr():
        # `work` was a copy; propagate the filled values back to the
        # caller's (non-contiguous) tensor.
        inp.copy_(work)
    return inp