Coverage for src/flag_gems/runtime/backend/_ascend/ops/amax.py: 0%

100 statements

Report generated by coverage.py v7.6.9 at 2026-03-20 02:31 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_min 

13 

14logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

15 

16 

@libentry()
@triton.jit
def amax_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full reduction: each program reduces one
    BLOCK_SIZE-wide chunk of the flattened input to a single maximum
    and writes it to ``mid[pid]``."""
    block_id = tle.program_id(0)
    # Identity element for max: lanes past the end of the input read the
    # dtype minimum so they can never win the reduction.
    neutral = get_dtype_min(inp.type.element_ty)

    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < M
    vals = tl.load(inp + offs, mask=in_bounds, other=neutral)
    tl.store(mid + block_id, tl.max(vals))

35 

36 

@libentry()
@triton.jit
def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the full reduction: a single program folds the
    per-block partial maxima in ``mid`` into the final scalar ``out``."""
    neutral = get_dtype_min(mid.type.element_ty)
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size
    # Invalid lanes contribute the dtype minimum, which cannot affect the max.
    partials = tl.load(mid + idx, mask=valid, other=neutral)
    tl.store(out, tl.max(partials))

47 

48 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
@triton.jit
def amax_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise amax for a (M, N) view of the input: out[m] = max(inp[m, :]).

    Assumes the reduced axis is contiguous and innermost (the wrapper calls
    dim_compress to arrange this). The launch grid may be smaller than the
    number of row-blocks; each program strides over extra blocks.
    """
    dtype = inp.type.element_ty
    # Identity element for max; also fills masked-out loads below.
    min_value = get_dtype_min(dtype)

    # Map the program id to the row of inp it should compute.
    workers = tle.num_programs(0)
    pid = tle.program_id(0)

    # Grid-stride over row-blocks: the wrapper caps the grid at 4096 programs,
    # so each program may own several BLOCK_M-row blocks.
    total_workloads = tl.cdiv(M, BLOCK_M)
    workloads = tl.cdiv(total_workloads, workers)

    for w in range(workloads):
        work_id = pid + w * workers
        rows = work_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        # Row base pointers; valid because the reduced axis is innermost
        # with stride 1 and row stride N.
        ninp = inp + rows * N
        nout = out + rows
        row_mask = rows < M

        # bfloat16 accumulates in float32 to avoid precision loss; other
        # dtypes accumulate natively.
        acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
        _all = tl.full([BLOCK_M, BLOCK_N], value=min_value, dtype=acc_type)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            # NOTE(review): relies on Triton rewriting `and` into an
            # element-wise combine of the two masks — confirm this holds for
            # the Triton version in use (`&` is the usual spelling).
            mask = row_mask and col_mask
            a = tl.load(ninp + cols, mask, other=min_value)
            # Running element-wise max over BLOCK_N-wide column tiles.
            _all = tl.maximum(_all, a)
        # Reduce the tile across columns; keep a trailing axis to match the
        # [BLOCK_M, 1] shape of `nout`. (`all` shadows the Python builtin,
        # which is harmless inside a @triton.jit body.)
        all = tl.max(_all, axis=1)[:, None]
        tl.store(nout, all, row_mask)

87 

88 

def amax(inp, dim=None, keepdim=False):
    """Return the maximum value of ``inp`` over the given dimension(s).

    Mirrors ``torch.amax``:

    Args:
        inp: input tensor.
        dim: int, sequence of ints, or None/empty sequence for a full
            reduction over all elements. Negative dims are accepted.
        keepdim: if True, reduced dimensions are kept with size 1.

    Returns:
        A tensor with the reduced maxima (a 0-d tensor for a full reduction
        with ``keepdim=False``).
    """
    logger.debug("GEMS_ASCEND AMAX")
    if dim is None or len(dim) == 0:
        # Full reduction: two-stage tree reduce. Stage 1 produces one partial
        # max per block (mid), stage 2 folds the partials into the result.
        M = inp.numel()
        # ~sqrt(M)-sized blocks balance the work between the two stages.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)
        dtype = inp.dtype
        mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        if not keepdim:
            out = torch.empty([], dtype=dtype, device=inp.device)
        else:
            # keepdim on a full reduction yields shape (1, 1, ..., 1).
            out = torch.empty([1] * inp.dim(), dtype=dtype, device=inp.device)
        with torch_device_fn.device(inp.device):
            amax_kernel_1[(mid_size, 1)](
                inp,
                mid,
                M,
                block_size,
            )
            amax_kernel_2[(1, 1)](
                mid, out, mid_size, block_mid
            )  # max block size is 128k, so mid does not requires int64 index
        return out
    else:
        if isinstance(dim, int):
            dim = [dim]
        # BUGFIX: the original asserted a bare generator expression, which is
        # always truthy, so invalid dims were never rejected. all(...) actually
        # evaluates the range check.
        assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"
        dtype = inp.dtype

        shape = list(inp.shape)
        dim = [d % inp.ndim for d in dim]  # normalize negative dims
        # Move the reduced dims to the innermost positions so that each output
        # element reduces over a contiguous run of N input elements.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1  # reduced dims collapse to size 1 (keepdim shape)
        M = inp.numel() // N  # number of independent reduction rows

        out = torch.empty(shape, dtype=dtype, device=inp.device)

        def grid(meta):
            # Cap the launch at 4096 programs; amax_kernel strides over any
            # remaining row-blocks internally.
            return (min(triton.cdiv(M, meta["BLOCK_M"]), 4096),)

        with torch_device_fn.device(inp.device):
            amax_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out