Coverage for src/flag_gems/runtime/backend/_ascend/ops/min.py: 0%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13from flag_gems.utils.limits import get_dtype_max 

14 

# Logger named after this module's leaf name under the Ascend ops namespace.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)

16 

17 

@libentry()
@triton.jit
def min_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the full reduction.

    Program ``i`` loads the elements ``inp[i * BLOCK_SIZE : (i + 1) * BLOCK_SIZE]``
    (clipped to ``M``) and stores their minimum into ``mid[i]``.
    """
    chunk_id = tle.program_id(0)
    cols = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < M
    # Out-of-range lanes are filled with the dtype's largest value so they can
    # never be selected as the minimum.
    pad = get_dtype_max(inp.type.element_ty)
    chunk = tl.load(inp + cols, mask=in_bounds, other=pad)
    tl.store(mid + chunk_id, tl.min(chunk))

35 

36 

@libentry()
@triton.jit
def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second pass of the full reduction.

    A single program loads the ``mid_size`` partial minima from ``mid`` and
    stores their overall minimum into ``out``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    # Lanes past mid_size are padded with the dtype's largest value.
    partials = tl.load(
        mid + lanes,
        mask=lanes < mid_size,
        other=get_dtype_max(mid.type.element_ty),
    )
    tl.store(out, tl.min(partials))

47 

48 

def heur_block_n(args):
    """Heuristic for BLOCK_N: the kernel argument ``N`` rounded up to a power of two."""
    n = args["N"]
    return triton.next_power_of_2(n)

51 

52 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("min"),
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def min_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Reduce ``inp`` along its middle axis, writing minima and their indices.

    ``inp`` is addressed as a dense (M, N, K) buffer
    (element ``[m, n, k]`` lives at ``m * N * K + n * K + k``). Program
    ``(pid_m, pid_k)`` handles ``BLOCK_M`` rows of the M axis at K position
    ``pid_k`` and walks the N axis in ``BLOCK_N``-wide tiles. The results are
    stored at ``out_value[m * K + k]`` and ``out_index[m * K + k]`` (int64).
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    dtype = inp.type.element_ty
    # A helper that returns a tl.dtype cannot be written in triton lang, so the
    # accumulator type is chosen inline: bfloat16 accumulates in float32.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    max_value = get_dtype_max(dtype)
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # Elementwise mask over the (BLOCK_M, BLOCK_N) tile.
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
        if dtype is tl.int64:
            # NOTE(review): re-applies the padding explicitly for int64,
            # presumably because the masked-load `other` fill is unreliable
            # for int64 on this backend — confirm before removing.
            inp_vals = tl.where(mask, inp_vals, max_value)
        # If return_indices is unsupported on a backend, local_argmin can be
        # computed with a separate tl.argmin(inp_vals, 1) instead.
        local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True)
        # Strict '<' keeps the earlier tile's index when a later tile only ties.
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

    offset_index = m_offset * K + pid_k
    out_value_ptrs = out_value + offset_index
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_value_ptrs, min_values, mask=mask1)
    tl.store(out_index_ptrs, argmin_values, mask=mask1)

104 

105 

def min(inp):
    """Return the minimum over all elements of ``inp`` as a 0-dim tensor.

    The reduction runs in two passes:
    ``min_kernel_1`` writes one partial minimum per block of about sqrt(M)
    elements into a scratch buffer, and ``min_kernel_2`` reduces that buffer
    to the scalar result. The result has ``inp``'s dtype and device.

    Raises:
        RuntimeError: if ``inp`` has no elements (as ``torch.min`` does).
    """
    logger.debug("GEMS_ASCEND MIN")
    M = inp.numel()
    if M == 0:
        # Without this guard, block_size would be 0 and triton.cdiv would
        # raise ZeroDivisionError.
        raise RuntimeError(
            "min(): Expected reduction dim to be specified for input.numel() == 0. "
            "Specify the reduction dim with the 'dim' argument."
        )
    # min_kernel_1 addresses the input as a flat buffer of M elements, so
    # strided views must be densified first (min_dim does the same).
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

121 

122 

# Result type of min_dim. It is created once at import time rather than on
# every call, so all results share a single class.
_MinOut = namedtuple("min", ["values", "indices"])


def min_dim(inp, dim=None, keepdim=False):
    """Return the minimum values of ``inp`` along ``dim`` and their indices.

    The tensor is viewed as (M, N, K), where N = ``inp.shape[dim]``,
    M is the product of the leading dims and K is the product of the
    trailing dims.

    Args:
        inp: input tensor; a contiguous copy is reduced.
        dim: dimension to reduce; negative values count from the end. It is
            required despite the ``None`` default, because the range check
            compares it with integers.
        keepdim: if True, the reduced dimension is kept with size 1.

    Returns:
        namedtuple ``min(values, indices)``. ``values`` has ``inp``'s dtype
        and ``indices`` is int64.

    Raises:
        AssertionError: if ``dim`` is out of range.
        IndexError: if the reduced dimension has size 0 (as ``torch.min`` does).
    """
    logger.debug("GEMS_ASCEND MIN DIM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    M = math.prod(shape[:dim])
    # K is computed directly rather than as numel() // M // N, which would
    # divide by zero when a leading or reduced dim has size 0.
    K = math.prod(shape[dim + 1 :])
    if N == 0:
        raise IndexError(f"min(): Expected reduction dim {dim} to have non-zero size.")

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)

    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    # With no output elements there is nothing to compute, and a zero-sized
    # launch grid is avoided.
    if M * K > 0:
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )
        with torch_device_fn.device(inp.device):
            min_kernel[grid](inp, out_value, out_index, M, N, K)
    return _MinOut(values=out_value, indices=out_index)