Coverage for src/flag_gems/runtime/backend/_cambricon/ops/masked_fill.py: 0%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import broadcastable_to, libentry, libtuner 

9 

10from ..utils import MAX_GRID_SIZE_X 

11 

# Module logger: child of the package-wide "flag_gems" logger.
# NOTE(review): `__name__.lstrip(".")` is presumably a guard against a
# leading-dot module name; for normal absolute module names it is a no-op.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("masked_fill"),
    key=["N"],
)
@triton.jit
def masked_fill_kernel(inp, expand_mask, value, out, N, BLOCK_SIZE: tl.constexpr):
    """Out-of-place masked fill: out[i] = value if expand_mask[i] else inp[i].

    Args:
        inp: pointer to the (contiguous) input elements.
        expand_mask: pointer to an int mask expanded to inp's shape.
        value: scalar fill value.
        out: pointer to the output buffer (same size as inp).
        N: total number of elements.
        BLOCK_SIZE: elements handled per program (autotuned).
    """
    # Linearize the program id across grid axes 0 and 1: the host wrapper
    # launches a (x, y, 1) grid when the block count exceeds the axis-0 grid
    # limit. The original code read only axis 0, so blocks past that limit
    # were never processed and the tail of `out` stayed uninitialized for
    # very large N. On a plain 1-D grid, axis-1 id is 0 and this is a no-op.
    pid = tl.program_id(axis=0) + tl.program_id(axis=1) * tl.num_programs(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N  # guard out-of-range lanes in the last block

    fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
    # Use element-wise `~`/`&` rather than Python `not`/`and`, which are not
    # reliably element-wise on Triton tensors across versions.
    keep = (~fill_mask) & mask
    cur_inp = tl.load(inp + offsets, mask=keep, other=0)
    tl.store(out + offsets, cur_inp, keep)
    tl.store(out + offsets, value, fill_mask & mask)

30 

31 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("masked_fill"),
    key=["N"],
)
@triton.jit
def masked_fill_kernel_self(inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr):
    """In-place masked fill: inp[i] = value wherever expand_mask[i] is nonzero.

    Args:
        inp: pointer to the (contiguous) tensor to mutate.
        expand_mask: pointer to an int mask expanded to inp's shape.
        value: scalar fill value.
        N: total number of elements.
        BLOCK_SIZE: elements handled per loop iteration (autotuned).
    """
    # Grid-stride loop over blocks: a fixed-size 1-D grid covers any N, so
    # the host side may cap the grid at any limit without losing elements.
    num_programs = tl.num_programs(0)
    pid = tl.program_id(axis=0)
    total_blocks = tl.cdiv(N, BLOCK_SIZE)

    for block_idx in range(pid, total_blocks, num_programs):
        offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < N  # last block may be partial

        fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
        # Broadcast the scalar fill value into a block-sized vector with the
        # element dtype of `inp`.
        cur_val = tl.full((BLOCK_SIZE,), value, dtype=inp.dtype.element_ty)
        # NOTE(review): Python `and` on Triton tensors — presumably lowered to
        # an element-wise logical and by the Triton release this backend
        # targets; `fill_mask & mask` is the conventional spelling. Confirm.
        tl.store(inp + offsets, cur_val, fill_mask and mask)

50 

51 

def masked_fill(inp, mask, value):
    """Return a copy of `inp` with `value` written at positions where `mask`
    (broadcast to `inp`'s shape) is true.

    `value` must be a Python int/float or a 0-dim tensor. Out-of-place
    counterpart of `masked_fill_`.
    """
    logger.debug("GEMS_CAMBRICON MASKED FILL")
    is_zero_dim_tensor = torch.is_tensor(value) and value.ndim == 0
    assert is_zero_dim_tensor or isinstance(
        value, (int, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Normalize a 0-dim tensor to a plain Python scalar.
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # 0-dim input: decide on the host, no kernel launch needed.
        if mask.item():
            return torch.tensor(value, dtype=inp.dtype, device=inp.device)
        return inp.clone()

    inp = inp.contiguous()
    expand_mask = mask.contiguous().expand(inp.shape)
    out = torch.empty_like(inp, dtype=inp.dtype, device=inp.device)

    N = inp.numel()
    if N == 0:
        return out

    def gridfn(meta):
        # Split the block count across grid axes x and y when it exceeds
        # the axis-0 grid-size limit.
        blocks = triton.cdiv(N, meta["BLOCK_SIZE"])
        x = min(MAX_GRID_SIZE_X, blocks)
        return (x, triton.cdiv(blocks, x), 1)

    masked_fill_kernel[gridfn](inp, expand_mask.to(torch.int), value, out, N)
    return out

91 

92 

def masked_fill_(inp, mask, value):
    """In-place masked fill: set inp[i] = value wherever `mask` (broadcast to
    `inp`'s shape) is true. Returns `inp`, mutated.

    `value` must be a Python int/float or a 0-dim tensor.
    """
    logger.debug("GEMS_CAMBRICON MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar; normalize to a Python scalar.
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # inp is a single value; assign on the host, no kernel launch.
        if mask.item():
            inp[()] = value
        return inp

    N = inp.numel()
    if N == 0:
        return inp

    # Bug fix: the kernel needs contiguous memory, but `inp.contiguous()` on a
    # non-contiguous tensor returns a *copy* — the original code filled that
    # copy and left the caller's tensor untouched, breaking the in-place
    # contract. Fill a contiguous work buffer and copy the result back.
    work = inp if inp.is_contiguous() else inp.contiguous()
    expand_mask = mask.contiguous().expand(work.shape)

    # Cap the 1-D grid at MAX_GRID_SIZE_X (consistent with masked_fill) rather
    # than a hard-coded 65535; the kernel's grid-stride loop covers any N
    # regardless of the cap.
    grid = lambda meta: (min(MAX_GRID_SIZE_X, triton.cdiv(N, meta["BLOCK_SIZE"])),)
    masked_fill_kernel_self[grid](work, expand_mask.to(torch.int), value, N)
    if work is not inp:
        inp.copy_(work)
    return inp