Coverage for src/flag_gems/runtime/backend/_metax/ops/masked_fill.py: 0%
65 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7import flag_gems.runtime as runtime
8from flag_gems.utils import broadcastable_to, libentry, libtuner
9from flag_gems.utils import triton_lang_extension as tle
11logger = logging.getLogger("flag_gems." + __name__)
@libentry()
@libtuner(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
@triton.jit
def masked_fill_kernel(inp, expand_mask, value, out, N, BLOCK_SIZE: tl.constexpr):
    """Out-of-place masked fill over flat buffers.

    For each linear index i < N: out[i] = value if expand_mask[i] else inp[i].
    All pointers are indexed linearly, so inp/expand_mask/out must be
    contiguous buffers of N elements.
    """
    block_start = tle.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N

    # Mask values arrive as integers; narrow to int1 for tl.where.
    should_fill = tl.load(expand_mask + idx, mask=in_bounds, other=0).to(tl.int1)
    src = tl.load(inp + idx, mask=in_bounds, other=0)
    result = tl.where(should_fill, value, src)
    tl.store(out + idx, result, mask=in_bounds)
@libentry()
@libtuner(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
@triton.jit
def masked_fill_kernel_self(inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr):
    """In-place masked fill over a flat buffer.

    For each linear index i < N: inp[i] = value if expand_mask[i].
    inp and expand_mask must be contiguous buffers of N elements.
    """
    pid = tle.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
    # Use elementwise `&`: Python's short-circuiting `and` is not an
    # elementwise operation on Triton tensors and does not compute a
    # per-lane conjunction of the two masks.
    tl.store(inp + offsets, value, mask=fill_mask & mask)
def masked_fill(inp, mask, value):
    """Out-of-place masked fill: returns a copy of `inp` where elements
    selected by `mask` (broadcast to `inp.shape`) are replaced by `value`.

    Args:
        inp: input tensor.
        mask: boolean tensor broadcastable to inp.shape.
        value: scalar, or a 0-dim tensor holding the fill value.

    Returns:
        A new tensor with the same shape/dtype/device as `inp`.
    """
    logger.debug("METAX GEMS MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # inp is a single-value
        return (
            torch.tensor(value, dtype=inp.dtype, device=inp.device)
            if mask.item()
            else inp.clone()
        )

    inp = inp.contiguous()
    mask = mask.contiguous()
    # expand() yields stride-0 (non-contiguous) views when broadcasting, but
    # the kernel indexes the mask with flat linear offsets over inp.numel()
    # elements — materialize a contiguous copy so every offset is valid.
    expand_mask = mask.expand(inp.shape).contiguous()
    out = torch.empty_like(inp, dtype=inp.dtype, device=inp.device)

    N = inp.numel()
    if N == 0:
        return out
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    masked_fill_kernel[grid](inp, expand_mask, value, out, N)
    return out
def masked_fill_(inp, mask, value):
    """In-place masked fill: writes `value` into the elements of `inp`
    selected by `mask` (broadcast to `inp.shape`).

    Args:
        inp: tensor to modify in place.
        mask: boolean tensor broadcastable to inp.shape.
        value: scalar, or a 0-dim tensor holding the fill value.

    Returns:
        `inp` (the filled tensor).
    """
    logger.debug("METAX GEMS MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # inp is a single-value
        if mask.item():
            inp[()] = value
        return inp

    # NOTE(review): for a non-contiguous `inp`, .contiguous() returns a copy,
    # so the caller's original tensor is not mutated — only the returned
    # tensor carries the fill. Preserved as-is; confirm intended upstream.
    inp = inp.contiguous()
    mask = mask.contiguous()
    # expand() yields stride-0 (non-contiguous) views when broadcasting, but
    # the kernel indexes the mask with flat linear offsets over inp.numel()
    # elements — materialize a contiguous copy so every offset is valid.
    expand_mask = mask.expand(inp.shape).contiguous()

    N = inp.numel()
    if N == 0:
        return inp
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    masked_fill_kernel_self[grid](inp, expand_mask, value, N)
    return inp