# src/flag_gems/runtime/backend/_metax/ops/nonzero.py
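#
# Runtime code generation for torch.nonzero on the MetaX backend: a Triton
# kernel plus its launch wrapper are generated per input rank, written to the
# code cache, and imported as a regular Python module.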
import importlib
import logging
import os
from typing import Any, Callable, Dict, List, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer

logger = logging.getLogger("flag_gems." + __name__)


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.writeline("from flag_gems.utils import libentry, libtuner")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.writeline("from flag_gems import runtime")
    code.writeline("from flag_gems.runtime import torch_device_fn")

    code.newline()
    code.newline()

    return code
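
# generate_nonzero_kernel below emits a Triton kernel in which each program
# instance covers BLOCK_SIZE flat indices: the inclusive prefix sum of the
# boolean mask gives every nonzero element its 1-based output row, and the
# flat index is decomposed into per-dimension coordinates from the innermost
# dimension outward (see the rank-2 example after the function).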


def generate_nonzero_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # the decorators
    code.writeline("@libentry()")
    code.writeline(
        "@triton.heuristics(runtime.get_heuristic_config('elementwise_generic'))"
    )
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("inp,")
            code.writeline("prefix_sum,")
            code.writeline("out,")
            code.writeline("n_elements: tl.constexpr,")
            code.writeline("ndim: tl.constexpr,")

            shape_args = ", ".join(f"dim{i}_size" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

        code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")
    # Kernel Code
    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline("offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offset < n_elements")
        code.newline()

        code.writeline("inp_vals = tl.load(inp + offset, mask=mask)")
        code.writeline("out_offset = tl.load(prefix_sum + offset, mask=mask) - 1")
        code.writeline("nonzero_mask = mask and inp_vals == True # noqa")
        code.writeline("idx_flat = offset")
        code.newline()

        for i in range(rank - 1, -1, -1):
            code.writeline(f"remainder = idx_flat % dim{i}_size")
            code.writeline(f"idx_flat //= dim{i}_size")
            code.writeline(
                f"tl.store(out + out_offset * ndim + {i}, remainder, mask=nonzero_mask)"
            )
        code.newline()

    code.newline()
    code.newline()
    return code
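
# For a rank-2 input the emitted kernel looks roughly like this (condensed;
# the generator writes one parameter per line):
#
#   @libentry()
#   @triton.heuristics(runtime.get_heuristic_config('elementwise_generic'))
#   @triton.jit
#   def _nonzero_jit_function(
#       inp, prefix_sum, out,
#       n_elements: tl.constexpr, ndim: tl.constexpr,
#       dim0_size, dim1_size,  # shape for src
#       BLOCK_SIZE: tl.constexpr,
#   ):
#       pid = tle.program_id(0)
#       offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#       mask = offset < n_elements
#       inp_vals = tl.load(inp + offset, mask=mask)
#       out_offset = tl.load(prefix_sum + offset, mask=mask) - 1
#       nonzero_mask = mask and inp_vals == True  # noqa
#       idx_flat = offset
#       remainder = idx_flat % dim1_size
#       idx_flat //= dim1_size
#       tl.store(out + out_offset * ndim + 1, remainder, mask=nonzero_mask)
#       remainder = idx_flat % dim0_size
#       idx_flat //= dim0_size
#       tl.store(out + out_offset * ndim + 0, remainder, mask=nonzero_mask)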


def parameter_for_wrapper() -> str:
    # inp_bool, prefix_sum, out, n_elements, inp_ndim, shape
    parameters: List[str] = []
    parameters.append("inp_bool")
    parameters.append("prefix_sum")
    parameters.append("out")
    parameters.append("n_elements")
    parameters.append("inp_ndim")
    parameters.append("shape")

    return ", ".join(parameters)
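
# Note: this parameter order must stay in sync with the positional launch in
# generate_destination_passing_wrapper and with NonzeroFunction.arg_key, which
# keys the overload cache on args[-2] (inp_ndim).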


def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline(
            'grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)'
        )
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            code.writeline("inp_bool, prefix_sum, out, n_elements, inp_ndim, ")
            if rank > 0:
                s = ", ".join(f"shape[{i}]" for i in range(rank))
                code.writeline(f"{s}")

        code.writeline(")")
        code.writeline("return out")

    return code
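
# For a rank-2 input the generated wrapper is, roughly:
#
#   def _nonzero_wrapper(inp_bool, prefix_sum, out, n_elements, inp_ndim, shape):
#       grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
#       _nonzero_jit_function[grid](
#           inp_bool, prefix_sum, out, n_elements, inp_ndim,
#           shape[0], shape[1]
#       )
#       return out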


def generate_code(
    inputs: Tuple[Any, ...],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # inputs: [inp_bool, prefix_sum, out, n_elements, inp_ndim, shape]
    shape = inputs[-1]
    rank = len(shape)
    code = generate_imports(code)
    code = generate_nonzero_kernel(rank, kernel_name, code)
    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
    return code
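
# Caches one generated module per input rank: __call__ keys the overload table
# on inp_ndim, writes the generated source under code_cache_dir() (pid-suffixed
# so concurrent processes do not clobber each other's files), then imports it
# and dispatches to its _nonzero_wrapper.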


class NonzeroFunction:
    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_nonzero_wrapper",
                "_nonzero_jit_function",
                code,
            )

            file_name = f"nonzero_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load the generated module
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_nonzero_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        # args: [inp_bool, prefix_sum, out, n_elements, inp_ndim, shape]
        return args[-2]


_nonzero_func = NonzeroFunction()
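
# Algorithm for nonzero(): flatten the input, build a boolean mask, take its
# inclusive cumsum so each nonzero element knows its 1-based output row, let
# the generated kernel scatter per-dimension coordinates into a worst-case
# (n_elements, ndim) buffer, then trim to the true count.
#
# Worked example for inp = [[0, 3], [4, 0]]:
#   flat mask  = [False, True, True, False]
#   prefix_sum = [0, 1, 2, 2]  -> num_nonzeros = 2
#   flat index 1 -> coords (0, 1), stored in row 0
#   flat index 2 -> coords (1, 0), stored in row 1
#   result     = [[0, 1], [1, 0]]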


def nonzero(inp, *, as_tuple=False):
    logger.debug("METAX GEMS NONZERO")

    assert len(inp.shape) > 0, "Invalid input shape, input dimension must be > 0"
    inp_ndim = inp.ndim

    inp = inp.contiguous()
    n_elements = inp.numel()
    inp_view = inp.view(n_elements)

    shape = inp.shape

    inp_bool = inp_view
    if inp_view.dtype != torch.bool:
        inp_bool = inp_view != 0

    prefix_sum = inp_bool.cumsum(axis=0)

    # worst-case allocation: every element could be nonzero
    num_nonzeros = n_elements
    out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)
    _nonzero_func(inp_bool, prefix_sum, out, n_elements, inp_ndim, shape)

    num_nonzeros = prefix_sum[n_elements - 1].item()
    out = out[0:num_nonzeros]

    if as_tuple:
        # torch.nonzero(..., as_tuple=True) returns one 1-D index tensor per
        # input dimension, so split along the coordinate axis (dim=1)
        return torch.unbind(out, dim=1)
    else:
        return out
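
# Minimal usage sketch (the device string is illustrative, not part of this
# module; use whatever device this backend targets):
#
#   x = torch.tensor([[0, 3], [4, 0]], device="cuda")
#   nonzero(x)                 # tensor([[0, 1], [1, 0]])
#   nonzero(x, as_tuple=True)  # (tensor([0, 1]), tensor([1, 0]))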