Coverage for src/flag_gems/ops/masked_fill.py: 94%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import broadcastable_to, pointwise_dynamic 

8 

9logger = logging.getLogger(__name__) 

10 

11 

@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, "NO_OPMATH")])
@triton.jit
def masked_fill_kernel(inp, expand_mask, value):
    # Element-wise select: wherever the mask element compares equal to 1 the
    # fill `value` is chosen, otherwise the corresponding `inp` element is kept.
    # The decorator's is_tensor=[True, True, False] marks the third argument
    # (`value`) as a non-tensor operand.
    return tl.where(expand_mask == 1, value, inp)

17 

18 

def masked_fill(inp, mask, value):
    """Out-of-place masked fill: pick ``value`` where ``mask`` is set, else ``inp``.

    Args:
        inp: Input tensor. This function does not write to it.
        mask: Tensor whose shape must be broadcastable to ``inp.shape``.
        value: Fill value; a Python ``int``/``float`` or a 0-dimensional tensor.
            A tensor value is converted to a Python scalar with ``.item()``.

    Returns:
        For a 0-dimensional ``inp``: a fresh tensor holding ``value`` (with
        ``inp``'s dtype and device) when the mask element is truthy, otherwise
        ``inp.clone()``. For any other ``inp``: whatever
        ``masked_fill_kernel(inp, expand_mask, value)`` returns (presumably a
        newly allocated tensor; confirm against ``pointwise_dynamic``).

    Raises:
        AssertionError: If ``value`` is neither an ``int``/``float`` nor a
            0-dimensional tensor, or if ``mask.shape`` is not broadcastable to
            ``inp.shape``.
    """
    logger.debug("GEMS MASKED FILL")

    # Validation is raised explicitly instead of using `assert` statements so
    # that it is not stripped when Python runs with -O. The exception type and
    # messages are the same as the assert-based checks they replace.
    if not (
        isinstance(value, (int, float))
        or (torch.is_tensor(value) and value.ndim == 0)
    ):
        raise AssertionError("masked_fill_ only supports a 0-dimensional value tensor")

    if torch.is_tensor(value):
        # Value can be a tensor or a scalar; the kernel takes a Python scalar.
        value = value.item()

    if not broadcastable_to(mask.shape, inp.shape):
        raise AssertionError(
            "The shape of mask must be broadcastable with the shape of the underlying tensor"
        )

    if inp.ndim == 0:
        # inp is a single value: decide on the host and skip the kernel.
        if mask.item():
            return torch.tensor(value, dtype=inp.dtype, device=inp.device)
        return inp.clone()

    # Give the mask the same shape as inp before handing both to the kernel.
    expand_mask = mask.expand(inp.shape)
    return masked_fill_kernel(inp, expand_mask, value)

43 

44 

def masked_fill_(inp, mask, value):
    """In-place masked fill: write ``value`` into ``inp`` where ``mask`` is set.

    Args:
        inp: Tensor to be updated.
        mask: Tensor whose shape must be broadcastable to ``inp.shape``.
        value: Fill value; a Python ``int``/``float`` or a 0-dimensional tensor.
            A tensor value is converted to a Python scalar with ``.item()``.

    Returns:
        For a 0-dimensional ``inp``: ``inp`` itself, overwritten with ``value``
        when the mask element is truthy. For any other ``inp``: whatever
        ``masked_fill_kernel(..., out0=inp)`` returns (presumably ``inp``
        itself, since it is passed as the output; confirm against
        ``pointwise_dynamic``).

    Raises:
        AssertionError: If ``value`` is neither an ``int``/``float`` nor a
            0-dimensional tensor, or if ``mask.shape`` is not broadcastable to
            ``inp.shape``.
    """
    logger.debug("GEMS MASKED FILL")

    # Validation is raised explicitly instead of using `assert` statements so
    # that it is not stripped when Python runs with -O. The exception type and
    # messages are the same as the assert-based checks they replace.
    if not (
        isinstance(value, (int, float))
        or (torch.is_tensor(value) and value.ndim == 0)
    ):
        raise AssertionError("masked_fill_ only supports a 0-dimensional value tensor")

    if torch.is_tensor(value):
        # Value can be a tensor or a scalar; the kernel takes a Python scalar.
        value = value.item()

    if not broadcastable_to(mask.shape, inp.shape):
        raise AssertionError(
            "The shape of mask must be broadcastable with the shape of the underlying tensor"
        )

    if inp.ndim == 0:
        # inp is a single value: update it on the host and skip the kernel.
        if mask.item():
            inp[()] = value
        return inp

    # Give the mask the same shape as inp; out0=inp passes inp as the output.
    expand_mask = mask.expand(inp.shape)
    return masked_fill_kernel(inp, expand_mask, value, out0=inp)