Coverage for src/flag_gems/runtime/backend/_cambricon/ops/nonzero.py: 0%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

11from ..utils import TOTAL_CORE_NUM 

12 

# All flag_gems loggers hang off the shared "flag_gems" root logger;
# lstrip(".") guards against a leading dot if __name__ is ever relative.
_pkg_logger = logging.getLogger("flag_gems")
logger = _pkg_logger.getChild(__name__.lstrip("."))

14 

15 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("nonzero"),
    key=[
        "n_elements",
    ],
)
@triton.jit
def nonzero_kernel(
    inp,  # flattened boolean input: 1 where the original element is non-zero
    prefix_sum,  # inclusive cumsum of `inp`; entry i-1 gives the output row (1-based)
    out,  # (n_elements, ndim) int64 buffer; rows [0, total_nonzeros) are written
    n_elements,  # number of elements in the flattened input
    shape,  # int32 device tensor holding the original tensor's shape
    ndim: tl.constexpr,  # rank of the original tensor
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter the multi-dimensional coordinates of each non-zero element
    of `inp` into `out`, using `prefix_sum` to compute each element's
    destination row (stable, ascending flat-index order)."""
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Grid-stride loop: each program covers every num_jobs-th block so the
    # launch can be capped at TOTAL_CORE_NUM programs regardless of size.
    for block_start_offset in range(block_start, n_elements, step):
        offset = block_start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < n_elements

        inp_vals = tl.load(inp + offset, mask=mask, other=0.0).to(tl.int1)
        # NOTE(review): relies on Triton accepting `and` between tensors
        # (non-short-circuit elementwise and) on this backend — confirm;
        # upstream Triton usually requires `&`.
        nonzero_mask = mask and inp_vals
        # Inclusive cumsum is 1-based at each hit, so subtract 1 for the row.
        out_row_offset = tl.load(prefix_sum + offset, mask=nonzero_mask) - 1
        out_col_offset = tl.arange(0, ndim)
        out_offsets = out_row_offset[:, None] * ndim + out_col_offset[None, :]
        out_vals = tl.zeros((BLOCK_SIZE, ndim), tl.int32)
        # Decompose the flat index into per-dimension coordinates,
        # innermost (fastest-varying) dimension first.
        idx_flat = offset
        for dim in range(ndim - 1, -1, -1):
            dim_size = tl.load(shape + dim)
            remainder = idx_flat % dim_size
            idx_flat //= dim_size
            # NOTE(review): in-place column assignment on a tl tensor is a
            # Cambricon-backend extension — not portable upstream Triton.
            out_vals[:, dim] = remainder
        # NOTE(review): int32 cast assumes n_elements * ndim < 2**31 —
        # verify for very large inputs.
        tl.store(out + out_offsets.to(tl.int32), out_vals, mask=nonzero_mask[:, None])

54 

55 

def nonzero(inp, *, as_tuple=False):
    """Return the indices of the non-zero elements of `inp`.

    Mirrors `torch.nonzero`: yields an (N, ndim) int64 tensor of
    coordinates in row-major (ascending flat-index) order, or, with
    `as_tuple=True`, a tuple of `ndim` 1-D tensors — one per input
    dimension, each of length N.

    Args:
        inp: input tensor; non-bool dtypes are compared against 0.
        as_tuple: return per-dimension index tensors instead of the
            (N, ndim) matrix.

    Returns:
        torch.Tensor of shape (N, inp.ndim), or a tuple of inp.ndim
        1-D tensors when `as_tuple=True`.
    """
    logger.debug("GEMS_CAMBRICON NONZERO")

    inp_ndim = inp.ndim

    inp = inp.contiguous()
    n_elements = inp.numel()

    # Empty input: `prefix_sum[n_elements - 1]` below would index an empty
    # tensor, so short-circuit with an empty (0, ndim) result.
    if n_elements == 0:
        out = torch.empty(0, inp_ndim, dtype=torch.int64, device=inp.device)
        return torch.unbind(out, dim=1) if as_tuple else out

    inp_view = inp.view(n_elements)

    # The kernel reads the shape from device memory to decompose flat
    # offsets into per-dimension coordinates.
    shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)

    inp_bool = inp_view
    if inp_view.dtype != torch.bool:
        inp_bool = inp_view != 0

    # Inclusive cumsum: entry i is the 1-based output row of element i when
    # it is non-zero (the kernel subtracts 1).
    prefix_sum = inp_bool.cumsum(axis=0)

    # Allocate for the worst case (every element non-zero); trimmed below.
    out = torch.empty(n_elements, inp_ndim, dtype=torch.int64, device=inp.device)
    grid = lambda meta: (
        min(triton.cdiv(n_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
    )
    with torch_device_fn.device(inp.device):
        nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)

    # The last prefix-sum entry is the total count of non-zeros.
    num_nonzeros = prefix_sum[n_elements - 1].item()
    out = out[0:num_nonzeros]

    if as_tuple:
        # torch.nonzero(as_tuple=True) returns one 1-D tensor PER DIMENSION,
        # i.e. the columns of `out` — unbind along dim=1, not dim=0
        # (dim=0 returned one tensor per non-zero element, which is wrong).
        return torch.unbind(out, dim=1)
    else:
        return out