Coverage for src/flag_gems/runtime/backend/_cambricon/ops/masked_fill.py: 0%
73 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import broadcastable_to, libentry, libtuner
10from ..utils import MAX_GRID_SIZE_X
# Module logger, namespaced as a child of the shared "flag_gems" logger.
# NOTE(review): lstrip(".") removes any leading dots from __name__; module
# names normally have none, so this is presumably defensive — confirm intent.
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
15@libentry()
16@libtuner(
17 configs=runtime.get_tuned_config("masked_fill"),
18 key=["N"],
19)
20@triton.jit
21def masked_fill_kernel(inp, expand_mask, value, out, N, BLOCK_SIZE: tl.constexpr):
22 pid = tl.program_id(axis=0)
23 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
24 mask = offsets < N
26 fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
27 cur_inp = tl.load(inp + offsets, mask=(not fill_mask) and mask, other=0)
28 tl.store(out + offsets, cur_inp, (not fill_mask) and mask)
29 tl.store(out + offsets, value, fill_mask and mask)
32@libentry()
33@libtuner(
34 configs=runtime.get_tuned_config("masked_fill"),
35 key=["N"],
36)
37@triton.jit
38def masked_fill_kernel_self(inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr):
39 num_programs = tl.num_programs(0)
40 pid = tl.program_id(axis=0)
41 total_blocks = tl.cdiv(N, BLOCK_SIZE)
43 for block_idx in range(pid, total_blocks, num_programs):
44 offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
45 mask = offsets < N
47 fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
48 cur_val = tl.full((BLOCK_SIZE,), value, dtype=inp.dtype.element_ty)
49 tl.store(inp + offsets, cur_val, fill_mask and mask)
52def masked_fill(inp, mask, value):
53 logger.debug("GEMS_CAMBRICON MASKED FILL")
54 assert (
55 (torch.is_tensor(value) and value.ndim == 0)
56 or isinstance(value, int)
57 or isinstance(value, float)
58 ), "masked_fill_ only supports a 0-dimensional value tensor"
59 if torch.is_tensor(value):
60 # Value can be a tensor or a scalar
61 value = value.item()
62 assert broadcastable_to(
63 mask.shape, inp.shape
64 ), "The shape of mask must be broadcastable with the shape of the underlying tensor"
66 if inp.ndim == 0:
67 # inp is a single-value
68 return (
69 torch.tensor(value, dtype=inp.dtype, device=inp.device)
70 if mask.item()
71 else inp.clone()
72 )
74 inp = inp.contiguous()
75 mask = mask.contiguous()
76 expand_mask = mask.expand(inp.shape)
77 out = torch.empty_like(inp, dtype=inp.dtype, device=inp.device)
79 N = inp.numel()
80 if N == 0:
81 return out
83 def gridfn(meta):
84 blocks = triton.cdiv(N, meta["BLOCK_SIZE"])
85 x = min(MAX_GRID_SIZE_X, blocks)
86 y = triton.cdiv(blocks, x)
87 return (x, y, 1)
89 masked_fill_kernel[gridfn](inp, expand_mask.to(torch.int), value, out, N)
90 return out
93def masked_fill_(inp, mask, value):
94 logger.debug("GEMS_CAMBRICON MASKED FILL")
95 assert (
96 (torch.is_tensor(value) and value.ndim == 0)
97 or isinstance(value, int)
98 or isinstance(value, float)
99 ), "masked_fill_ only supports a 0-dimensional value tensor"
100 if torch.is_tensor(value):
101 # Value can be a tensor or a scalar
102 value = value.item()
103 assert broadcastable_to(
104 mask.shape, inp.shape
105 ), "The shape of mask must be broadcastable with the shape of the underlying tensor"
107 if inp.ndim == 0:
108 # inp is a single-value
109 if mask.item():
110 inp[()] = value
111 return inp
113 inp = inp.contiguous()
114 mask = mask.contiguous()
115 expand_mask = mask.expand(inp.shape)
117 N = inp.numel()
118 if N == 0:
119 return inp
120 grid = lambda meta: (min(65535, triton.cdiv(N, meta["BLOCK_SIZE"])),)
121 masked_fill_kernel_self[grid](inp, expand_mask.to(torch.int), value, N)
122 return inp