Coverage for src/flag_gems/ops/max.py: 59%

97 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import dim_compress, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as tle 

13from flag_gems.utils.limits import get_dtype_min 

14 

# Module logger; max() and max_dim() below emit a debug line through it on every call.
logger = logging.getLogger(__name__)

16 

17 

@libentry()
@triton.jit
def max_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flat max reduction.

    Program ``b`` reduces elements ``[b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)`` of
    the flat input to one value and writes it to ``mid[b]``. Lanes at or past
    ``M`` are padded with ``get_dtype_min`` of the element type (assumed to be
    the dtype's lowest value) so the padding cannot raise the result.
    """
    block_id = tle.program_id(0)
    cols = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < M
    fill = get_dtype_min(inp.type.element_ty)
    chunk = tl.load(inp + cols, mask=in_bounds, other=fill)
    tl.store(mid + block_id, tl.max(chunk))

35 

36 

@libentry()
@triton.jit
def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the flat max reduction: fold the partial maxima in ``mid``.

    Reads the first ``BLOCK_MID`` slots of ``mid``. Slots at or past ``mid_size``
    are padded with ``get_dtype_min`` of the element type. The maximum is
    written to the single element ``out``. Only ``BLOCK_MID`` slots are read, so
    the caller must pass ``BLOCK_MID >= mid_size``.
    """
    slots = tl.arange(0, BLOCK_MID)
    valid = slots < mid_size
    fill = get_dtype_min(mid.type.element_ty)
    partials = tl.load(mid + slots, mask=valid, other=fill)
    tl.store(out, tl.max(partials))

47 

48 

def heur_block_n(args):
    """Return ``args["N"]`` rounded up to the next power of two.

    ``args`` is presumably the launch-argument mapping a Triton heuristic
    receives (TODO confirm; no caller is visible here). Only its ``"N"`` entry
    is read.
    """
    row_len = args["N"]
    block_n = triton.next_power_of_2(row_len)
    return block_n

51 

52 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def max_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise max and argmax over ``inp`` addressed as a row-major (M, N) matrix.

    Each program owns ``BLOCK_M`` consecutive rows. It scans their ``N`` columns
    in ``BLOCK_N``-wide tiles, keeping a running maximum and the column where it
    was found. Results go to ``out_value[row]`` and ``out_index[row]`` for every
    row ``< M``.
    """
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    row_ok = rows < M

    in_ty = inp.type.element_ty
    # bfloat16 keeps its running maximum in float32; other dtypes stay native.
    acc_ty = tl.float32 if in_ty is tl.bfloat16 else in_ty
    lowest = get_dtype_min(in_ty)

    best_val = tl.full([BLOCK_M], value=lowest, dtype=acc_ty)
    best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        tile_mask = row_ok[:, None] & (cols[None, :] < N)
        tile = tl.load(
            inp + rows[:, None] * N + cols[None, :],
            mask=tile_mask,
            other=lowest,
        )
        tile_max, tile_arg = tl.max(tile, axis=1, return_indices=True)
        # Strict '>' keeps the index from the earliest tile when a later tile
        # only ties the running maximum.
        improved = tile_max > best_val
        best_val = tl.where(improved, tile_max, best_val)
        best_idx = tl.where(improved, start + tile_arg, best_idx)

    tl.store(out_value + rows, best_val, mask=row_ok)
    tl.store(out_index + rows, best_idx, mask=row_ok)

95 

96 

def max(inp):
    """Reduce ``inp`` to its single largest element.

    Two-pass reduction over the flattened, contiguous input:

    - ``max_kernel_1`` writes one partial maximum per block of
      ``next_power_of_2(ceil(sqrt(numel)))`` elements into ``mid``.
    - ``max_kernel_2`` folds ``mid`` into the result.

    Args:
        inp: input tensor; made contiguous before the kernels are launched.

    Returns:
        A 0-dim tensor with ``inp``'s dtype, on ``inp``'s device.

    Raises:
        RuntimeError: if ``inp`` has no elements (same error as ``torch.max``
            for a full reduction over an empty tensor).
    """
    logger.debug("GEMS MAX")
    inp = inp.contiguous()
    M = inp.numel()
    if M == 0:
        # With M == 0, sqrt(M) == 0 and the block sizing below degenerates.
        # Fail with torch.max's own error instead of an obscure arithmetic one.
        raise RuntimeError(
            "max(): Expected reduction dim to be specified for input.numel() == 0. "
            "Specify the reduction dim with the 'dim' argument."
        )
    # Blocks of ~sqrt(M) elements: stage 1 launches ~sqrt(M) programs and
    # stage 2 then reduces ~sqrt(M) partial maxima.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    # Stage 2 reads only BLOCK_MID slots, so round mid_size up to cover them all.
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        max_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        max_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

113 

114 

# Result type of max_dim. Created once at import time instead of on every call.
_MaxOut = namedtuple("max", ["values", "indices"])


def max_dim(inp, dim=None, keepdim=False):
    """Compute the maximum and its index along ``dim``.

    ``inp`` is passed through ``dim_compress`` for ``dim``. The result is then
    addressed by ``max_kernel`` as a row-major (M, N) matrix, with N = size of
    ``dim`` and M = numel // N. This presumes ``dim_compress`` lays the reduced
    dimension out contiguously as the innermost axis (TODO confirm).

    Args:
        inp: input tensor.
        dim: dimension to reduce; negative values count from the end. Required
            despite the ``None`` default.
        keepdim: if True, keep ``dim`` with size 1 in the outputs; otherwise
            squeeze it away.

    Returns:
        A ``max(values, indices)`` namedtuple. ``values`` has ``inp``'s dtype and
        ``indices`` is int64.

    Raises:
        TypeError: if ``dim`` is None.
        AssertionError: if ``dim`` is out of range. This check uses ``assert``
            and is therefore skipped under ``python -O``.
        IndexError: if ``dim`` has size 0 (same error as ``torch.max``).
    """
    logger.debug("GEMS MAX DIM")
    if dim is None:
        # The signature defaults dim to None, but a reduction dim is required.
        raise TypeError("max_dim(): argument 'dim' must be an int, not None")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = list(inp.shape)
    dim = dim % inp.ndim
    N = shape[dim]
    if N == 0:
        # Nothing to reduce, and M = numel // N below would divide by zero.
        raise IndexError(
            f"max(): Expected reduction dim {dim} to have non-zero size."
        )
    inp = dim_compress(inp, dim)
    shape[dim] = 1
    M = inp.numel() // N

    out_value = torch.empty(shape, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(shape, dtype=torch.int64, device=inp.device)

    if not keepdim:
        # squeeze returns views, so the kernel still writes into these buffers.
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        max_kernel[grid](inp, out_value, out_index, M, N)
    return _MaxOut(values=out_value, indices=out_index)