Coverage for src/flag_gems/runtime/backend/_ascend/ops/mm.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10 

# Per-op logger named after this module's basename
# (resolves to "flag_gems.runtime._ascend.ops.mm").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

12 

13 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("mm"))
@triton.jit
def mm_kernel(
    A,  # pointer to the left operand, logical shape (M, K)
    B,  # pointer to the right operand, logical shape (K, N)
    C,  # pointer to the output, logical shape (M, N)
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,  # element strides of A along (M, K)
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,  # element strides of B along (K, N)
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,  # element strides of C along (M, N)
    stride_cn: tl.constexpr,
    dot_out_dtype: tl.constexpr,  # accumulator dtype for tl.dot (e.g. tl.float32)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # number of row-tiles grouped together for L2 reuse
    EVEN_K: tl.constexpr,  # heuristic flag: K divides evenly into BLOCK_K tiles
):
    """Tiled GEMM kernel: C[M, N] = A[M, K] @ B[K, N].

    Each program instance computes one BLOCK_M x BLOCK_N tile of C by
    marching over K in BLOCK_K steps and accumulating partial products
    in ``dot_out_dtype`` precision before casting to C's element type.
    """
    # matrix multiplication
    pid = tl.program_id(0)
    # Second grid axis: K-split slab index.  NOTE(review): the launcher in
    # this file uses a 1-D grid, so pid_z is always 0 here and no split-K
    # reduction occurs — confirm before relying on split-K behavior.
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: GROUP_M consecutive
    # row-tiles share the same column sweep, improving B-tile reuse.
    width = GROUP_M * grid_n
    group_id = pid // width
    # The last group may be ragged when grid_m is not a multiple of GROUP_M.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    ram = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # row indices of this C tile
    rbn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # column indices of this C tile
    # ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    # rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)  # K offsets for this slab
    # pointers to the first BLOCK_M x BLOCK_K tile of A and
    # BLOCK_K x BLOCK_N tile of B for this program
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            # K divides evenly: unmasked loads are safe on every iteration.
            a = tl.load(A)
            b = tl.load(B)
        else:
            # Mask the tail of K, substituting zeros so out-of-range
            # elements contribute nothing to the dot product.
            k_remaining = K - k * (BLOCK_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if a.dtype != b.dtype:
            # Mixed-precision inputs: promote both operands to C's element
            # type so tl.dot sees matching dtypes.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        # Advance both tile pointers one BLOCK_K step along K.
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # Bounds mask guards partial edge tiles along both M and N.
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)

85 

86 

# Floating-point dtypes ranked from lowest to highest promotion priority.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever of ``a`` and ``b`` ranks higher in ``_ordered_datatypes``.

    Both arguments must be members of ``_ordered_datatypes``; identical
    inputs are returned unchanged.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # A dtype appearing later in the ordering has higher priority.
    if _ordered_datatypes.index(a) > _ordered_datatypes.index(b):
        return a
    return b

102 

103 

def mm(a, b):
    """Matrix-multiply two 2-D tensors on the Ascend backend: returns a @ b.

    The output dtype is the higher-priority of the two input dtypes
    (see ``get_higher_dtype``); accumulation happens in float32.
    """
    logger.debug("GEMS_ASCEND MM")
    out_device = a.device
    # Fall back to contiguous copies when neither dimension is unit-stride,
    # since the kernel addresses tiles via raw strides.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # Shapes must be conformable for matrix multiplication.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # Allocate the output in the promoted dtype on the input device.
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=out_device, dtype=out_dtype)
    dot_out_dtype = tl.float32

    # 1-D launch grid: one program per BLOCK_M x BLOCK_N output tile.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
        )
    return c