Coverage for src/flag_gems/runtime/backend/_cambricon/ops/any.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10 

11from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op2 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

# torch.any: tests whether any element of the input evaluates to True. If the
# input's dtype is not bool, it tests whether any element is non-zero.
# In the Triton kernels it is therefore sufficient to test for non-zero values.

17 

18 

@triton.jit
def reduce_any(a, b):
    # Combine function for tl.reduce: logical OR of two boolean (int1)
    # partial results, so the reduction is True iff any element was non-zero.
    return a or b

22 

23 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
@triton.jit
def any_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise any-reduction over an (M, N) contiguous view:
    # out[m] = True iff any(inp[m, :] != 0).
    # Map the program id to the row of inp it should compute.
    pid = tl.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # OR-accumulate (element != 0) across column tiles of width BLOCK_N.
    _any = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # Masked-out lanes load 0.0, which compares != 0 as False and
        # therefore cannot flip the accumulator.
        a = tl.load(inp + cols, mask, other=0.0)
        _any = _any or (a != 0)
    # Collapse the accumulated tile along the column axis, then store one
    # boolean per row (row_mask guards the ragged tail of the last program).
    any = tl.reduce(_any, axis=1, combine_fn=reduce_any)
    tl.store(out, any[:, None], row_mask)

52 

53 

@libentry()
@triton.autotune(configs=cfggen_reduce_op2(), key=["M"])
@triton.jit
def any_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
    ITER_NUM: tl.constexpr,
):
    # Full (flat) any-reduction: atomically ORs 1 into *out (an int32 scalar)
    # if any of the M elements of inp is non-zero.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    _tmp = tl.zeros([BLOCK_SIZE], dtype=tl.int1)
    # Widen the starting offset to int64 so pointer arithmetic does not
    # overflow for very large inputs.
    block_start = block_start.to(tl.int64)
    # Each program walks the flat input in strides of num_jobs * BLOCK_SIZE,
    # OR-accumulating a per-lane boolean vector.
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=0.0)
        _tmp = _tmp or (inp_val != 0)

    # Manual binary-tree OR over _tmp: each step ORs the second quarter/half
    # of the active slice into the first, halving the active length. Written
    # out explicitly instead of tl.reduce as a backend optimization.
    # NOTE(review): this only folds everything into _tmp[0] if ITER_NUM covers
    # log2(BLOCK_SIZE) halvings — assumed guaranteed by the autotune configs
    # from cfggen_reduce_op2; confirm.
    # Reset to original reduce programming mode after optimizing the tl.reduce.
    for x in tl.static_range(1, int(ITER_NUM), 1):
        _tmp[: BLOCK_SIZE // (2**x)] = (
            _tmp[: BLOCK_SIZE // (2**x)]
            or _tmp[BLOCK_SIZE // (2**x) : (BLOCK_SIZE // (2**x)) * 2]
        )

    # Merge this program's scalar result with the other programs' results.
    tl.atomic_or(out, _tmp[0].to(tl.int32))

84 

85 

def any(inp):
    """Reduce over all elements of ``inp``.

    Returns a 0-dim bool tensor that is True iff any element of ``inp``
    is non-zero (mirrors ``torch.any`` with no ``dim`` argument).
    """
    logger.debug("GEMS_CAMBRICON ANY")
    numel = inp.numel()

    # One program per BLOCK_SIZE chunk, capped at the core count; the kernel
    # grid-strides over any remainder.
    def grid(meta):
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    # int32 accumulator because the kernel combines partial results with
    # tl.atomic_or; converted to bool on the way out.
    acc = torch.zeros([], dtype=torch.int32, device=inp.device)

    with torch_device_fn.device(inp.device):
        any_kernel_1[grid](inp, acc, numel)

    return acc.to(torch.bool)

97 

98 

def any_dim(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with ``any`` along a single dimension.

    ``dim is None`` reduces over all elements (optionally keeping a
    broadcastable shape); otherwise reduces along ``dim``, squeezing it
    unless ``keepdim`` is set. Mirrors ``torch.any(inp, dim, keepdim)``.
    """
    logger.debug("GEMS_CAMBRICON ANY DIM")
    shape = list(inp.shape)

    # Full reduction: delegate to the flat kernel.
    if dim is None:
        result = any(inp)
        if keepdim:
            result = torch.reshape(result, [1] * inp.ndim)
        return result

    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    # Move the reduced dimension to be innermost/contiguous so the kernel
    # can treat the data as an (M, N) matrix reduced along N.
    compressed = dim_compress(inp, dim)
    reduce_len = shape[dim]
    shape[dim] = 1
    num_rows = compressed.numel() // reduce_len

    result = torch.empty(shape, dtype=torch.bool, device=compressed.device)

    grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),)
    with torch_device_fn.device(compressed.device):
        any_kernel_dim[grid](compressed, result, num_rows, reduce_len)
    if not keepdim:
        result = result.squeeze(dim=dim)
    return result

122 

123 

def any_dims(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with ``any`` along one or several dimensions.

    ``dim`` may be None (full reduction), an int (single dim), or an
    iterable of ints. Mirrors ``torch.any(inp, dim, keepdim)`` for the
    multi-dim overload.
    """
    logger.debug("GEMS_CAMBRICON ANY DIMS")

    if dim is None or isinstance(dim, int):
        return any_dim(inp, dim=dim, keepdim=keepdim)
    # BUG FIX: the original asserted a bare generator expression, which is
    # always truthy, so out-of-range dims were never rejected. Use all().
    assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    # Normalize negative dims, then make the reduced dims contiguous so the
    # kernel sees an (M, N) matrix reduced along N.
    dim = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1  # reduced dims collapse to size 1 in the output shape
    M = inp.numel() // N

    out = torch.empty(shape, dtype=torch.bool, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        any_kernel_dim[grid](inp, out, M, N)
    if not keepdim:
        # torch.squeeze accepts a list of dims (PyTorch >= 2.0).
        out = out.squeeze(dim=dim)
    return out