Coverage for src/flag_gems/runtime/backend/_metax/ops/nonzero.py: 0%

126 statements  


import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer

logger = logging.getLogger("flag_gems." + __name__)



def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.writeline("from flag_gems.utils import libentry, libtuner")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.writeline("from flag_gems import runtime")
    code.writeline("from flag_gems.runtime import torch_device_fn")

    code.newline()
    code.newline()

    return code



def generate_nonzero_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # the decorators
    code.writeline("@libentry()")
    code.writeline(
        "@triton.heuristics(runtime.get_heuristic_config('elementwise_generic'))"
    )
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("inp,")
            code.writeline("prefix_sum,")
            code.writeline("out,")
            code.writeline("n_elements: tl.constexpr,")
            code.writeline("ndim: tl.constexpr,")

            shape_args = ", ".join(f"dim{i}_size" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

        code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline("offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offset < n_elements")
        code.newline()

        code.writeline("inp_vals = tl.load(inp + offset, mask=mask)")
        code.writeline("out_offset = tl.load(prefix_sum + offset, mask=mask) - 1")
        code.writeline("nonzero_mask = mask and inp_vals == True # noqa")
        code.writeline("idx_flat = offset")
        code.newline()

        for i in range(rank - 1, -1, -1):
            code.writeline(f"remainder = idx_flat % dim{i}_size")
            code.writeline(f"idx_flat //= dim{i}_size")
            code.writeline(
                f"tl.store(out + out_offset * ndim + {i}, remainder, mask=nonzero_mask)"
            )
        code.newline()

    code.newline()
    code.newline()
    return code


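# For illustration: for a rank-2 input, generate_nonzero_kernel emits a Triton kernel
# along these lines (exact whitespace depends on IndentedBuffer; shown here only to
# make the generated code easier to picture):
#
#     @libentry()
#     @triton.heuristics(runtime.get_heuristic_config('elementwise_generic'))
#     @triton.jit
#     def _nonzero_jit_function(
#         inp,
#         prefix_sum,
#         out,
#         n_elements: tl.constexpr,
#         ndim: tl.constexpr,
#         dim0_size, dim1_size, # shape for src
#         BLOCK_SIZE: tl.constexpr,
#     ):
#         pid = tle.program_id(0)
#         offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#         mask = offset < n_elements
#
#         inp_vals = tl.load(inp + offset, mask=mask)
#         out_offset = tl.load(prefix_sum + offset, mask=mask) - 1
#         nonzero_mask = mask and inp_vals == True # noqa
#         idx_flat = offset
#
#         remainder = idx_flat % dim1_size
#         idx_flat //= dim1_size
#         tl.store(out + out_offset * ndim + 1, remainder, mask=nonzero_mask)
#         remainder = idx_flat % dim0_size
#         idx_flat //= dim0_size
#         tl.store(out + out_offset * ndim + 0, remainder, mask=nonzero_mask)

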

def parameter_for_wrapper() -> str:
    # inp_bool, prefix_sum, out, n_elements, inp_ndim, shape
    parameters: List[str] = []
    parameters.append("inp_bool")
    parameters.append("prefix_sum")
    parameters.append("out")
    parameters.append("n_elements")
    parameters.append("inp_ndim")
    parameters.append("shape")

    return ", ".join(parameters)



def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline(
            'grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)'
        )
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            code.writeline("inp_bool, prefix_sum, out, n_elements, inp_ndim, ")
            if rank > 0:
                s = ", ".join(f"shape[{i}]" for i in range(rank))
                code.writeline(f"{s}")

        code.writeline(")")
        code.writeline("return out")

    return code


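# For illustration: for a rank-2 input, generate_destination_passing_wrapper emits a
# launch wrapper roughly like the following (whitespace again depends on IndentedBuffer):
#
#     def _nonzero_wrapper(inp_bool, prefix_sum, out, n_elements, inp_ndim, shape):
#         grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
#         _nonzero_jit_function[grid](
#             inp_bool, prefix_sum, out, n_elements, inp_ndim,
#             shape[0], shape[1]
#         )
#         return out

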

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # inputs: [inp_bool, prefix_sum, out, n_elements, inp_ndim, shape]
    shape = inputs[-1]
    rank = len(shape)
    code = generate_imports(code)
    code = generate_nonzero_kernel(rank, kernel_name, code)
    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
    return code



class NonzeroFunction:
    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_nonzero_wrapper",
                "_nonzero_jit_function",
                code,
            )

            file_name = f"nonzero_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_nonzero_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        # args: [inp_bool, prefix_sum, out, n_elements, inp_ndim, shape]
        return args[-2]


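# Note: NonzeroFunction caches one generated wrapper per input rank. arg_key returns
# inp_ndim (args[-2]), so all inputs with the same number of dimensions reuse the
# module written to code_cache_dir() as nonzero_rank_<ndim>_pid_<pid>.py.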

_nonzero_func = NonzeroFunction()



def nonzero(inp, *, as_tuple=False):
    logger.debug("METAX GEMS NONZERO")

    assert len(inp.shape) > 0, "Invalid input shape, input dimension must be > 0"
    inp_ndim = inp.ndim
    inp = inp.contiguous()
    n_elements = inp.numel()
    inp_view = inp.view(n_elements)

    shape = inp.shape

    inp_bool = inp_view
    if inp_view.dtype != torch.bool:
        inp_bool = inp_view != 0

    prefix_sum = inp_bool.cumsum(axis=0)

    # allocate for the worst case (every element nonzero), then trim below
    num_nonzeros = n_elements
    out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)
    _nonzero_func(inp_bool, prefix_sum, out, n_elements, inp_ndim, shape)

    num_nonzeros = prefix_sum[n_elements - 1].item()
    out = out[0:num_nonzeros]

    if as_tuple:
        # one 1-D index tensor per input dimension, matching torch.nonzero(as_tuple=True)
        return torch.unbind(out, dim=1)
    else:
        return out
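

# A minimal usage sketch, assuming a device/backend where these Metax Triton kernels
# are available (the tensor values and the "cuda" device string below are illustrative
# placeholders, not part of this module):
#
#     import torch
#     from flag_gems.runtime.backend._metax.ops.nonzero import nonzero
#
#     x = torch.tensor([[0, 3], [5, 0]], device="cuda")
#     idx = nonzero(x)                         # tensor([[0, 1], [1, 0]]), shape (2, 2)
#     rows, cols = nonzero(x, as_tuple=True)   # (tensor([0, 1]), tensor([1, 0]))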