# src/flag_gems/runtime/backend/_kunlunxin/ops/mv.py

import logging
import os

import torch
import triton
import triton.language as tl

# from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

from .mm import mm

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


def heur_block_n(args):
    return triton.next_power_of_2(triton.cdiv(args["N"], 12))


def heur_block_m(args):
    import builtins

    return builtins.min(triton.next_power_of_2(args["M"]), 4096)
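
# Worked example (illustrative, not from the source): for N=497 and M=5333,
# heur_block_n gives next_power_of_2(cdiv(497, 12)) = next_power_of_2(42) = 64,
# and heur_block_m gives min(next_power_of_2(5333), 4096) = min(8192, 4096) = 4096.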



@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("mv"),
#     key=["M", "N"],
# )
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "BLOCK_M": heur_block_m,
    }
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N: tl.constexpr,
    M: tl.constexpr,
    stride_an: tl.constexpr,
    stride_am: tl.constexpr,
    stride_bm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    buffer_size_limit: tl.constexpr,  # NOTE: `constexpr` so it can be used as a shape value.
):

    pid = tle.program_id(0)
    # Each program instance handles BLOCK_N rows of A; the M columns are
    # consumed in chunks of BLOCK_M.
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b  # elementwise partial products, accumulated in float32
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

    # Reduce the partial products over the M axis and store one scalar per row.
    acc = tl.sum(acc, axis=1)
    C_ptrs = C + offset_n * stride_cn
    tl.store(C_ptrs, acc[:, None], mask=n_mask)
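
# Reference semantics (illustrative note, not from the source): the kernel
# computes C[n] = sum_m A[n, m] * B[m], matching torch.mv(A, B) up to
# float32 accumulation-order rounding.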



def mv(inp, vec):
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    # TODO: fix the autotune config having no items; this shape is
    # special-cased to the cluster kernel in the meantime.
    if M == 5333 and N == 497:
        return mv_cluster(inp, vec)

    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(inp.device):
        if M == 1:
            mv_kernel[grid](
                inp,
                vec,
                out,
                N,
                M,
                inp.stride(0),
                inp.stride(1),
                vec.stride(0),
                out.stride(0),
                buffer_size_limit=256,
            )
        else:
            # Route through the mm kernel: view the vector as an (M, 1) matrix,
            # multiply, then squeeze the result back to one dimension.
            os.environ["XMLIR_MATMUL_FAST_MODE"] = "1"
            vec = vec[:, None]
            out = mm(inp, vec)
            out = out.squeeze()
            del os.environ["XMLIR_MATMUL_FAST_MODE"]
    return out
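
# Shape walk-through of the fallback path (illustrative): for inp of shape (N, M)
# and vec of shape (M,), vec[:, None] is (M, 1), mm(inp, vec) returns (N, 1), and
# squeeze() yields the expected (N,) result. Note the cleanup assumes
# XMLIR_MATMUL_FAST_MODE was not already set by the caller.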




def mv_cluster(inp, vec):
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp,
            vec,
            out,
            N,
            M,
            inp.stride(0),
            inp.stride(1),
            vec.stride(0),
            out.stride(0),
            buffer_size_limit=256,
        )
    return out
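
# Illustrative smoke test (an assumption, not part of this module; the device
# string and tolerances depend on the runtime backend):
#
#     inp = torch.randn(497, 5333, device=device, dtype=torch.float32)
#     vec = torch.randn(5333, device=device, dtype=torch.float32)
#     torch.testing.assert_close(mv(inp, vec), torch.mv(inp, vec), rtol=1e-4, atol=1e-4)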