Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/nonzero.py: 0%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13 

14 

def nonzero_kernel_heur_block_size(args, cluster_num=12):
    """Pick ``BLOCK_SIZE`` for ``nonzero_kernel`` from the element count.

    The flattened input is divided into ``cluster_num`` roughly equal parts
    (ceil division) and the part size is rounded up to the next power of two.

    Args:
        args: Mapping of kernel argument names to values; only
            ``args["n_elements"]`` (a non-negative int) is read.
        cluster_num: Number of parts to split the work into. Defaults to 12,
            the value previously hard-coded here.
            NOTE(review): presumably the cluster count of the target
            device -- confirm before changing it.

    Returns:
        A power of two >= 1. The result is clamped to 1 so that an empty
        input can never yield a block size of 0, which the launch grid
        would then divide by.
    """
    # Ceil division on plain ints: ceil(n_elements / cluster_num).
    per_cluster = -(-args["n_elements"] // cluster_num)
    if per_cluster <= 1:
        return 1
    # Smallest power of two that is >= per_cluster.
    return 1 << (per_cluster - 1).bit_length()

17 

18 

@libentry()
# NOTE: autotuning through runtime.get_tuned_config("nonzero") (keyed on
# "n_elements") is disabled; BLOCK_SIZE is supplied by the heuristic instead.
@triton.heuristics(values={"BLOCK_SIZE": nonzero_kernel_heur_block_size})
@triton.jit
def nonzero_kernel(
    inp,
    prefix_sum,
    out,
    n_elements: tl.constexpr,
    shape,
    ndim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter the multi-dimensional index of every non-zero element.

    Each program handles ``BLOCK_SIZE`` consecutive flat positions. For a
    position whose ``inp`` value is true, the inclusive ``prefix_sum`` at
    that position minus one is used as the output row, and the flat position
    is decomposed into one coordinate per dimension (last dimension first)
    using the sizes read from ``shape``. Coordinates are stored at
    ``out + row * ndim + dim``.

    Args:
        inp: Pointer to the flattened input; values are cast to ``tl.int1``.
        prefix_sum: Pointer to the running count of true values of ``inp``.
        out: Pointer to the output buffer, addressed as ``row * ndim + dim``.
        n_elements: Number of flat positions to process.
        shape: Pointer to the per-dimension sizes of the original input.
        ndim: Number of dimensions (entries readable from ``shape``).
        BLOCK_SIZE: Number of flat positions handled per program.
    """
    block_id = tle.program_id(0)

    flat_pos = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = flat_pos < n_elements

    is_set = tl.load(inp + flat_pos, mask=in_bounds).to(tl.int1)
    # prefix_sum is 1-based at each true element, so subtract one for the row.
    row = tl.load(prefix_sum + flat_pos, mask=in_bounds) - 1

    # Only lanes that are both in range and non-zero may write.
    write_mask = in_bounds and is_set  # noqa

    remaining = flat_pos
    for dim in range(ndim - 1, -1, -1):
        extent = tl.load(shape + dim)
        coord = remaining % extent
        remaining = remaining // extent
        tl.store(out + row * ndim + dim, coord, mask=write_mask)

57 

58 

def nonzero(inp, *, as_tuple=False):
    """Return the indices of the non-zero elements of ``inp``.

    The input is made contiguous and flattened, converted to a boolean mask,
    and an inclusive cumulative sum over that mask gives each non-zero
    element its (1-based) output row. ``nonzero_kernel`` then writes the
    per-dimension coordinates of each non-zero element into that row.

    Args:
        inp: Input tensor.
        as_tuple: When false (default), return a single ``(z, inp.ndim)``
            ``int64`` tensor, one row per non-zero element. When true, return
            a tuple with one 1-D tensor per column of that result, i.e. one
            tensor of ``z`` indices per input dimension, as ``torch.nonzero``
            documents.

    Returns:
        A tensor or a tuple of tensors as described above, on ``inp.device``.
    """
    logger.debug("GEMS NONZERO")

    inp_ndim = inp.ndim

    inp = inp.contiguous()
    n_elements = inp.numel()

    if n_elements == 0:
        # Nothing to scan: skip the kernel launch and avoid indexing the
        # empty prefix sum with -1 further down.
        out = torch.empty((0, inp_ndim), dtype=torch.int64, device=inp.device)
        return torch.unbind(out, dim=1) if as_tuple else out

    inp_view = inp.view(n_elements)

    # Dimension sizes are read by the kernel to unravel flat positions.
    shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)

    inp_bool = inp_view
    if inp_view.dtype != torch.bool:
        inp_bool = inp_view != 0

    prefix_sum = inp_bool.cumsum(axis=0)

    # The non-zero count is not known before the scan is read back, so
    # allocate for the worst case (every element non-zero) and trim later.
    out = torch.empty(n_elements, inp_ndim, dtype=torch.int64, device=inp.device)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(inp.device):
        nonzero_kernel[grid](
            inp_bool,
            prefix_sum,
            out,
            n_elements,
            shape,
            inp_ndim,
            isCloseUnrollControl=True,
            is_use_mask_zero=True,
        )

    # The last entry of the inclusive scan is the total non-zero count.
    num_nonzeros = prefix_sum[n_elements - 1].item()
    out = out[0:num_nonzeros]

    if as_tuple:
        # One 1-D index tensor per input dimension (the columns of ``out``).
        return torch.unbind(out, dim=1)
    return out