Coverage for src/flag_gems/ops/masked_fill.py: 33%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import broadcastable_to, pointwise_dynamic
# Module-level logger; masked_fill and masked_fill_ each emit a debug
# message through it on every call.
logger = logging.getLogger(__name__)
# Elementwise body shared by masked_fill and masked_fill_.
# `inp` and `expand_mask` are tensor arguments; `value` is passed as a plain
# (non-tensor) scalar, per is_tensor=[True, True, False].  The output dtype
# promotion is driven by argument 0 (`inp`) with NO_OPMATH.
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, "NO_OPMATH")])
@triton.jit
def masked_fill_kernel(inp, expand_mask, value):
    # Select `value` where the mask element equals 1, else keep the input.
    fill_here = expand_mask == 1
    return tl.where(fill_here, value, inp)
def masked_fill(inp, mask, value):
    """Return ``inp`` with ``value`` written wherever ``mask`` is set.

    Out-of-place counterpart of ``masked_fill_``: unlike the in-place
    variant, the kernel is not given ``out0=inp``, so ``inp`` itself is not
    used as the output buffer.

    Args:
        inp: Input tensor.
        mask: Tensor whose shape must be broadcastable to ``inp.shape``; it
            is expanded to ``inp.shape`` and positions where it equals 1 are
            filled.
        value: Fill value, either a Python ``int``/``float`` or a
            0-dimensional tensor (which is unwrapped with ``.item()``).

    Returns:
        The filled result tensor.  For a 0-dim ``inp`` this is either a new
        0-dim tensor holding ``value`` (same dtype/device as ``inp``) or a
        clone of ``inp``.

    Raises:
        AssertionError: If ``value`` is neither a scalar nor a 0-dim tensor,
            or if ``mask`` is not broadcastable to ``inp.shape``.
    """
    logger.debug("GEMS MASKED FILL")
    # The message names this op (the previous text said "masked_fill_",
    # copied from the in-place variant).
    assert (torch.is_tensor(value) and value.ndim == 0) or isinstance(
        value, (int, float)
    ), "masked_fill only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # The kernel takes `value` as a non-tensor argument, so unwrap the
        # 0-dim tensor into a Python scalar.
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # Single-element input: skip the kernel.  mask is presumably 0-dim
        # here too (it must broadcast to an empty shape), so .item() reads
        # its only element.
        return (
            torch.tensor(value, dtype=inp.dtype, device=inp.device)
            if mask.item()
            else inp.clone()
        )

    expand_mask = mask.expand(inp.shape)
    return masked_fill_kernel(inp, expand_mask, value)
def masked_fill_(inp, mask, value):
    """Fill ``inp`` in place with ``value`` wherever ``mask`` is set.

    Args:
        inp: Tensor to modify; it is also passed as the kernel's output
            buffer (``out0=inp``).
        mask: Tensor whose shape must be broadcastable to ``inp.shape``; it
            is expanded to ``inp.shape`` before the kernel runs.
        value: Fill value, either a Python ``int``/``float`` or a
            0-dimensional tensor (converted to a Python scalar).

    Returns:
        For a 0-dim ``inp``, ``inp`` itself; otherwise whatever the kernel
        returns when writing into ``inp``.

    Raises:
        AssertionError: If ``value`` is neither a scalar nor a 0-dim tensor,
            or if ``mask`` is not broadcastable to ``inp.shape``.
    """
    logger.debug("GEMS MASKED FILL")
    is_scalar_like = (torch.is_tensor(value) and value.ndim == 0) or isinstance(
        value, (int, float)
    )
    assert is_scalar_like, "masked_fill_ only supports a 0-dimensional value tensor"

    # Hand the kernel a plain Python scalar rather than a 0-dim tensor.
    if torch.is_tensor(value):
        value = value.item()

    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim != 0:
        # General case: broadcast the mask and write results back into inp.
        return masked_fill_kernel(inp, mask.expand(inp.shape), value, out0=inp)

    # Single-element input: assign directly instead of launching a kernel.
    if mask.item():
        inp[()] = value
    return inp