Coverage for src/flag_gems/ops/nonzero.py: 72%

46 statements  

coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger(__name__)


@libentry()
@triton.heuristics(runtime.get_heuristic_config("elementwise_generic"))
@triton.jit
def nonzero_kernel(
    inp,
    prefix_sum,
    out,
    n_elements,
    shape,
    ndim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)

    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Load a block of input values and each element's output row: the
    # inclusive prefix sum minus one gives the 0-based row index.
    inp_vals = tl.load(inp + offset, mask=mask).to(tl.int1)
    out_offset = tl.load(prefix_sum + offset, mask=mask) - 1

    # Only write rows for in-bounds elements that are actually nonzero.
    nonzero_mask = mask and inp_vals  # noqa

    # Unravel the flat offset into per-dimension indices, last dim first.
    idx_flat = offset
    for dim in range(ndim - 1, -1, -1):
        dim_size = tl.load(shape + dim)
        remainder = idx_flat % dim_size
        idx_flat //= dim_size
        tl.store(out + out_offset * ndim + dim, remainder, mask=nonzero_mask)


def nonzero(inp, *, as_tuple=False):
    logger.debug("GEMS NONZERO")

    inp_ndim = inp.ndim

    inp = inp.contiguous()
    n_elements = inp.numel()
    inp_view = inp.view(n_elements)

    shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)

    # Treat the input as a boolean mask of nonzero positions.
    inp_bool = inp_view
    if inp_view.dtype != torch.bool:
        inp_bool = inp_view != 0

    # Inclusive prefix sum assigns each nonzero element its output row.
    prefix_sum = inp_bool.cumsum(axis=0)

    # Allocate for the worst case (every element nonzero); trimmed below.
    num_nonzeros = n_elements
    out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(inp.device):
        nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)

    # The last prefix-sum entry is the total number of nonzero elements.
    num_nonzeros = prefix_sum[n_elements - 1].item()
    out = out[0:num_nonzeros]

    if as_tuple:
        # Match torch.nonzero(as_tuple=True): one 1-D index tensor per dimension.
        return torch.unbind(out, dim=1)
    else:
        return out
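
For reference, a minimal usage sketch of the nonzero wrapper above. It is illustrative only (not part of nonzero.py) and assumes flag_gems is installed and a Triton-capable CUDA device is available.

# Illustrative usage sketch; the import path follows src/flag_gems/ops/nonzero.py.
import torch

from flag_gems.ops.nonzero import nonzero

x = torch.tensor([[0, 3], [5, 0]], device="cuda")
idx = nonzero(x)
# idx holds one row of indices per nonzero element, here
# tensor([[0, 1], [1, 0]]), matching torch.nonzero(x).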