Coverage for src/flag_gems/runtime/backend/_hygon/ops/all.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

logger = logging.getLogger(name=__name__)
# torch.all checks whether every element of the input evaluates to True. For a
# non-bool dtype that is the same as "every element is non-zero", so the Triton
# kernels below simply compare each loaded element against 0.

17 

18 

@triton.jit
def reduce_all(a, b):
    # Combine function handed to tl.reduce: logical AND of two partial results.
    both_true = a and b
    return both_true

22 

23 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program handles BLOCK_M consecutive rows of the row-major (M, N)
    # view of ``inp`` (row r starts at inp + r * N) and writes one boolean per
    # row to ``out + r``.
    row_ids = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_in_range = row_ids < M
    row_base = inp + row_ids * N

    # Running element-wise AND over every column tile visited so far.
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_ids = start + tl.arange(0, BLOCK_N)[None, :]
        tile_mask = rows_in_range and (col_ids < N)
        # Masked-out lanes load 1 so they can never turn the AND False.
        tile = tl.load(row_base + col_ids, tile_mask, other=1.0)
        acc = acc and (tile != 0)

    row_result = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out + row_ids, row_result[:, None], rows_in_range)

52 

53 

@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the full reduction: program ``block_id`` reduces the
    # BLOCK_SIZE-element chunk starting at block_id * BLOCK_SIZE of the flat
    # input to one boolean stored at mid[block_id]. ``mid_size`` is unused here.
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Out-of-range lanes read 1 (truthy) so they do not affect the AND.
    chunk = tl.load(inp + idx, mask=in_bounds, other=1.0)
    nonzero = chunk != 0
    chunk_all = tl.reduce(nonzero, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, chunk_all)

71 

72 

@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full reduction: a single program ANDs the MID_SIZE
    # partial results in ``mid`` into the scalar ``out``.
    lanes = tl.arange(0, BLOCK_MID)
    # Padding lanes (>= MID_SIZE) read 1 so they do not affect the AND.
    partial = tl.load(mid + lanes, mask=lanes < MID_SIZE, other=1)
    flags = partial.to(tl.int1)
    tl.store(out, tl.reduce(flags, axis=0, combine_fn=reduce_all))

82 

83 

def all(inp):
    """Return a 0-d bool tensor that is True iff every element of ``inp`` is
    truthy (non-zero, or True for bool inputs), mirroring ``torch.all(inp)``.

    Uses a two-stage reduction: ``all_kernel_1`` reduces chunks of roughly
    sqrt(numel) elements into ``mid``, then ``all_kernel_2`` reduces ``mid``
    into the scalar result.

    An empty input yields True (vacuous truth), matching ``torch.all``.
    """
    logger.debug("GEMS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # With numel == 0 the sizing below produces block_size == 0
        # (next_power_of_2(0) == 0) and triton.cdiv(0, 0) divides by zero.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # all_kernel_1 reads the input as a flat buffer (inp + offset), which is
    # only correct for contiguous memory. This is a no-op when already contiguous.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out

99 

100 

def all_dim(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with logical AND along one dim (or all dims), like
    ``torch.all(inp, dim, keepdim)``.

    Args:
        inp: input tensor; non-bool values count as True when non-zero.
        dim: int dim to reduce, or None to reduce every element.
        keepdim: keep the reduced dim(s) as size 1 in the result.

    Returns:
        A bool tensor. Reducing an empty dim yields True, as in PyTorch.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS ALL DIM")
    shape = list(inp.shape)
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    # PyTorch validates the dim of a 0-d tensor as if it had one size-1 dim,
    # so dim 0 / -1 are accepted and the result stays 0-d.
    rank = max(inp.ndim, 1)
    assert -rank <= dim < rank, "Invalid dim"
    if inp.ndim == 0:
        return all(inp)

    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Either the reduced dim is empty (every row is vacuously True) or the
        # output itself is empty. Both are covered by ones(), and this also
        # avoids numel() // N dividing by zero when N == 0.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Assumes dim_compress moves ``dim`` last and returns a contiguous
        # tensor, so that each of the M rows holds N contiguous elements
        # (the layout all_kernel_dim reads). TODO confirm in flag_gems.utils.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

124 

125 

def all_dims(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with logical AND over several dims at once, like
    ``torch.all(inp, dim=(d0, d1, ...), keepdim)``.

    An int or None ``dim`` is delegated to ``all_dim``.

    Args:
        inp: input tensor; non-bool values count as True when non-zero.
        dim: None, an int, or an iterable of ints naming the dims to reduce.
        keepdim: keep the reduced dims as size 1 in the result.

    Returns:
        A bool tensor. Reducing over an empty dim yields True, as in PyTorch.

    Raises:
        AssertionError: if any entry of ``dim`` is out of range.
    """
    logger.debug("GEMS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)

    # PyTorch validates dims of a 0-d tensor as if it had one size-1 dim.
    rank = max(inp.ndim, 1)
    # The builtin ``all`` is shadowed by this module's ``all`` op, so check each
    # entry with an explicit loop. Asserting on a bare generator expression is
    # always truthy and would let out-of-range dims wrap silently via ``%``.
    for d in dim:
        assert -rank <= d < rank, "Invalid dim"
    if inp.ndim == 0:
        return all(inp)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Empty reduced extent means every output is vacuously True (or the
        # output is empty). This also avoids numel() // N dividing by zero.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Assumes dim_compress moves the reduced dims last and returns a
        # contiguous tensor, i.e. M rows of N contiguous elements as
        # all_kernel_dim expects. TODO confirm in flag_gems.utils.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out