Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/amax.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12from ..utils.block_size_utils import get_block_size_1d 

13 

# Module logger nested under the shared "flag_gems" root logger; lstrip(".")
# guards against a leading dot if __name__ were ever a relative-style name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

@libentry()
@triton.jit
def amax_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First reduction stage: each program reduces one BLOCK_SIZE chunk of the
    flattened input to a single partial maximum stored at ``mid[pid]``."""
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Out-of-range lanes read -inf so they can never win the max.
    values = tl.load(inp + offsets, mask=in_bounds, other=-float("inf"))
    partial_max = tl.max(values)
    tl.store(mid + block_id, partial_max)

34 

35 

@libentry()
@triton.jit
def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second reduction stage: a single program folds all partial maxima in
    ``mid`` down to one scalar written to ``out``."""
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size
    # Invalid lanes contribute -inf, the identity element for max.
    partials = tl.load(mid + idx, mask=valid, other=-float("inf"))
    tl.store(out, tl.max(partials))

45 

46 

def heur_m_block_size(args):
    """Heuristic row-block size: spread M rows across the hardware clusters,
    rounded up to a power of two (Triton block sizes must be powers of two)."""
    rows_per_cluster = triton.cdiv(args["M"], 12)  # 12 = cluster_num on this backend
    return triton.next_power_of_2(rows_per_cluster)

49 

50 

def heur_n_block_size(args):
    """Heuristic column-block size: the whole reduced axis rounded up to a
    power of two, capped at 8192."""
    # Explicit builtins.min to be immune to any shadowing of `min` in this
    # module's namespace (triton.language exports a `min` as well).
    import builtins

    candidate = triton.next_power_of_2(args["N"])
    return builtins.min(candidate, 8192)

55 

56 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def amax_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise max over the innermost (compressed) axis of an (M, N) view:
    each program owns a BLOCK_M slab of rows and sweeps the N columns in
    BLOCK_N-wide tiles."""
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_mask = rows < M
    row_base = inp + rows * N
    out_ptrs = out + rows

    # Running per-lane maxima, accumulated in fp32 for accuracy.
    acc = tl.full([BLOCK_M, BLOCK_N], value=-float("inf"), dtype=tl.float32)
    for col_start in range(0, N, BLOCK_N):
        cols = col_start + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        tile = tl.load(row_base + cols, mask, other=-float("inf")).to(tl.float32)
        # Belt-and-braces: force masked-out lanes to -inf before the fold.
        tile = tl.where(mask, tile, -float("inf"))
        acc = tl.maximum(acc, tile)
    # Reduce across the tile's column axis; keep a trailing unit dim so the
    # store shape matches out_ptrs.
    row_max = tl.max(acc, axis=1)[:, None]
    tl.store(out_ptrs, row_max, row_mask)

92 

93 

def amax(inp, dim=None, keepdim=False):
    """Return the maximum value of ``inp`` over the given dimension(s).

    Mirrors ``torch.amax``: ``dim`` may be None or an empty sequence (full
    reduction over all elements), a single int, or a sequence of ints
    (negative indices allowed); ``keepdim`` keeps reduced dims as size 1.

    Raises:
        AssertionError: if any entry of ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS AMAX")
    # BUGFIX: normalize an int dim up front; the original only did this in the
    # else-branch, so an int reached `len(dim)` below and raised TypeError.
    if isinstance(dim, int):
        dim = [dim]
    if dim is None or len(dim) == 0:
        # Full reduction: two-stage block reduce over the flattened input.
        M = inp.numel()
        block_size = get_block_size_1d(M, inp.element_size())
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)
        dtype = inp.dtype
        mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        if not keepdim:
            out = torch.empty([], dtype=dtype, device=inp.device)
        else:
            # keepdim on a full reduction: every dim collapses to size 1.
            out = torch.empty([1] * inp.dim(), dtype=dtype, device=inp.device)
        with torch_device_fn.device(inp.device):
            amax_kernel_1[(mid_size, 1)](
                inp, mid, M, block_size, buffer_size_limit=2048
            )
            amax_kernel_2[(1, 1)](
                mid, out, mid_size, block_mid, buffer_size_limit=2048
            )  # max block size is 128k, so mid does not require int64 indexing
        return out
    else:
        # BUGFIX: the original asserted a bare generator expression, which is
        # always truthy and so never validated anything; check each entry.
        assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"
        dtype = inp.dtype

        shape = list(inp.shape)
        dim = [d % inp.ndim for d in dim]
        # Permute so the reduced dims are innermost and contiguous; the kernel
        # then sees a flat (M, N) problem with N = product of reduced extents.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N

        out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            amax_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out