Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/masked_fill.py: 0%
84 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import broadcastable_to, libentry
8from flag_gems.utils import triton_lang_extension as tle
10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def masked_fill_kernel_heur_block_size(args):
    """Choose a power-of-two BLOCK_SIZE that splits N across the 12 clusters."""
    per_cluster = triton.cdiv(args["N"], 12)  # cluster_num
    return triton.next_power_of_2(per_cluster)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
# @triton.heuristics(
#     values={
#         "BLOCK_SIZE": masked_fill_kernel_heur_block_size,
#     },
# )
@triton.jit
def masked_fill_kernel(
    inp, expand_mask, value, out, N: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    # Out-of-place masked fill over a flat buffer of N elements:
    # out[i] = value where expand_mask[i] != 0, else out[i] = inp[i].
    # Each program instance covers one BLOCK_SIZE-sized stripe.
    pid = tle.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N  # guards the ragged tail of the last block
    # Mask values arrive as int (see the launcher's `.to(torch.int)`);
    # narrow to 1-bit booleans for the predicated loads/stores below.
    fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
    # NOTE(review): Python `not`/`and` on tensor values relies on the vendor
    # (XPU) Triton dialect; upstream Triton would need `~`/`&` — confirm.
    cur_inp = tl.load(inp + offsets, mask=(not fill_mask) and mask, other=0)
    # Inactive lanes are pointed at offset -1, but the store predicate keeps
    # them from writing; only the copy-through lanes store cur_inp.
    out_offset_1 = tl.where((not fill_mask) and mask, offsets, -1)
    tl.store(out + out_offset_1, cur_inp, (not fill_mask) and mask)
    # Complementary lanes store the scalar fill value.
    out_offset_2 = tl.where(fill_mask and mask, offsets, -1)
    tl.store(out + out_offset_2, value, fill_mask and mask)
def masked_fill_kernel_self_heur_block_size(args):
    """Choose a power-of-two BLOCK_SIZE that splits N across the 12 clusters."""
    per_cluster = triton.cdiv(args["N"], 12)  # cluster_num
    return triton.next_power_of_2(per_cluster)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
# @triton.heuristics(
#     values={
#         "BLOCK_SIZE": masked_fill_kernel_self_heur_block_size,
#     },
# )
@triton.jit
def masked_fill_kernel_self(inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr):
    # In-place masked fill over a flat buffer of N elements:
    # inp[i] = value where expand_mask[i] != 0; other elements are untouched.
    pid = tle.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N  # guards the ragged tail of the last block
    # Mask values arrive as int (see the launcher's `.to(torch.int)`).
    fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)
    # NOTE(review): `and` on tensor predicates relies on the vendor (XPU)
    # Triton dialect; upstream Triton would need `&` — confirm.
    tl.store(inp + offsets, value, fill_mask and mask)
def masked_fill(inp, mask, value):
    """Return a copy of ``inp`` with ``value`` written where ``mask`` is True.

    Args:
        inp: Input tensor.
        mask: Boolean tensor whose shape is broadcastable to ``inp.shape``.
        value: Python scalar or 0-dim tensor; written to every masked slot.

    Returns:
        A new tensor with the same shape, dtype and device as ``inp``.

    Raises:
        AssertionError: if ``value`` is a tensor with ndim > 0, or ``mask``
            is not broadcastable to ``inp``.
    """
    logger.debug("GEMS MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar; pass the kernel a plain scalar.
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # inp is a single value; no kernel launch needed.
        return (
            torch.tensor(value, dtype=inp.dtype, device=inp.device)
            if mask.item()
            else inp.clone()
        )

    inp = inp.contiguous()
    mask = mask.contiguous()
    expand_mask = mask.expand(inp.shape)
    out = torch.empty_like(inp)  # empty_like already matches dtype/device

    N = inp.numel()
    if N == 0:
        return out
    grid = 12  # cluster count on the XPU target — TODO confirm
    BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(N, grid))

    import os

    # Fix: remember any pre-existing values so we restore the caller's
    # environment instead of clobbering it, and use try/finally so the
    # cleanup also runs when the kernel launch raises.
    _env_keys = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
    saved = {k: os.environ.get(k) for k in _env_keys}
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
    try:
        masked_fill_kernel[grid,](
            inp,
            expand_mask.to(torch.int),
            value,
            out,
            N,
            BLOCK_SIZE,
            isCloseUnrollControl=True,
            buffer_size_limit=2048,
        )
    finally:
        for k, v in saved.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
    return out
def masked_fill_(inp, mask, value):
    """Fill ``inp`` in place with ``value`` where ``mask`` is True.

    Args:
        inp: Tensor to mutate; also returned.
        mask: Boolean tensor whose shape is broadcastable to ``inp.shape``.
        value: Python scalar or 0-dim tensor; written to every masked slot.

    Returns:
        ``inp`` (mutated in place).

    Raises:
        AssertionError: if ``value`` is a tensor with ndim > 0, or ``mask``
            is not broadcastable to ``inp``.
    """
    logger.debug("GEMS MASKED FILL")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar; pass the kernel a plain scalar.
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # inp is a single value; no kernel launch needed.
        if mask.item():
            inp[()] = value
        return inp

    # Fix: ``.contiguous()`` returns a *copy* when ``inp`` is non-contiguous,
    # so the original code filled the copy and the caller's tensor was never
    # mutated. Run the kernel on the contiguous buffer, then copy back.
    work = inp.contiguous()
    mask = mask.contiguous()
    expand_mask = mask.expand(work.shape)

    N = work.numel()
    if N == 0:
        return inp

    import os

    # Fix: remember any pre-existing values so we restore the caller's
    # environment instead of clobbering it, and use try/finally so the
    # cleanup also runs when the kernel launch raises.
    _env_keys = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
    saved = {k: os.environ.get(k) for k in _env_keys}
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"

    grid = 12  # cluster count on the XPU target — TODO confirm
    BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(N, grid))
    try:
        masked_fill_kernel_self[grid,](
            work, expand_mask.to(torch.int), value, N, BLOCK_SIZE, buffer_size_limit=2048
        )
    finally:
        for k, v in saved.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v

    if work is not inp:
        # Propagate the result back into the caller's (non-contiguous) tensor.
        inp.copy_(work)
    return inp