Coverage for src/flag_gems/runtime/backend/_metax/ops/masked_fill.py: 0%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7import flag_gems.runtime as runtime 

8from flag_gems.utils import broadcastable_to, libentry, libtuner 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger("flag_gems." + __name__) 

12 

13 

@libentry()
@libtuner(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
@triton.jit
def masked_fill_kernel(inp, expand_mask, value, out, N, BLOCK_SIZE: tl.constexpr):
    """Out-of-place masked fill over a flat range of N elements.

    Copies ``inp`` into ``out``, replacing every position where
    ``expand_mask`` is nonzero with the scalar ``value``. All pointers are
    indexed with flat contiguous offsets.
    """
    block_start = tle.program_id(axis=0) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < N

    # Mask values arrive as integers; narrow to 1-bit booleans for tl.where.
    do_fill = tl.load(expand_mask + offsets, mask=in_bounds, other=0).to(tl.int1)
    vals = tl.load(inp + offsets, mask=in_bounds, other=0)
    tl.store(out + offsets, tl.where(do_fill, value, vals), mask=in_bounds)

26 

27 

@libentry()
@libtuner(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
@triton.jit
def masked_fill_kernel_self(inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr):
    """In-place masked fill over a flat range of N elements.

    Writes the scalar ``value`` into ``inp`` at every position where
    ``expand_mask`` is nonzero. All pointers are indexed with flat
    contiguous offsets.
    """
    pid = tle.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
    # Use the elementwise `&` rather than Python's short-circuiting `and`:
    # `and` is a BoolOp on tensor values and is not the documented way to
    # combine Triton boolean tensors elementwise.
    tl.store(inp + offsets, value, mask=fill_mask & mask)

38 

39 

def masked_fill(inp, mask, value):
    """Out-of-place masked fill.

    Returns a new tensor equal to ``inp`` with ``value`` written at every
    position where ``mask`` (broadcast to ``inp.shape``) is True.

    Args:
        inp: input tensor.
        mask: boolean tensor broadcastable to ``inp.shape``.
        value: Python int/float, or a 0-dimensional tensor holding the fill value.
    """
    logger.debug("METAX GEMS MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # inp is a single value; no kernel launch needed.
        return (
            torch.tensor(value, dtype=inp.dtype, device=inp.device)
            if mask.item()
            else inp.clone()
        )

    inp = inp.contiguous()
    # The kernel reads the mask with flat contiguous offsets, so the
    # broadcast-expanded mask must be materialized: a bare expand() is a
    # zero-stride view and flat indexing of it would read out of bounds
    # whenever mask.shape != inp.shape.
    expand_mask = mask.expand(inp.shape).contiguous()
    out = torch.empty_like(inp)

    N = inp.numel()
    if N == 0:
        return out
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    masked_fill_kernel[grid](inp, expand_mask, value, out, N)
    return out

73 

74 

def masked_fill_(inp, mask, value):
    """In-place masked fill.

    Writes ``value`` into ``inp`` at every position where ``mask``
    (broadcast to ``inp.shape``) is True, and returns ``inp``.

    Args:
        inp: tensor to modify in place.
        mask: boolean tensor broadcastable to ``inp.shape``.
        value: Python int/float, or a 0-dimensional tensor holding the fill value.
    """
    logger.debug("METAX GEMS MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # inp is a single value; assign directly without a kernel launch.
        if mask.item():
            inp[()] = value
        return inp

    # NOTE(review): if `inp` is non-contiguous, contiguous() returns a copy and
    # the fill lands in that copy, not the caller's tensor — the filled copy is
    # still returned, but true in-place semantics are lost. Confirm callers
    # only pass contiguous tensors here.
    inp = inp.contiguous()
    # Materialize the broadcast mask: the kernel uses flat contiguous offsets,
    # and a zero-stride expand() view would be read out of bounds whenever
    # mask.shape != inp.shape.
    expand_mask = mask.expand(inp.shape).contiguous()

    N = inp.numel()
    if N == 0:
        return inp
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    masked_fill_kernel_self[grid](inp, expand_mask, value, N)
    return inp