Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mv.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import copy 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

12from ..utils import MAX_NRAM_SIZE 

13 

# Package-level "flag_gems" logger is created first so the child inherits
# its handlers/level; lstrip(".") guards against a leading-dot module name.
_pkg_logger = logging.getLogger("flag_gems")
logger = _pkg_logger.getChild(__name__.lstrip("."))

15 

16 

def config_prune(configs, named_args, **kwargs):
    """Prune autotune candidate configs for the mv kernel.

    If a config's working set (BLOCK_N * M elements, x4 bytes, x3 —
    presumably three fp32 buffers; confirm against kernel tiling) fits in
    on-chip NRAM, the config is rewritten to cover the whole reduction
    dimension in one tile (BLOCK_M = M, num_stages = 1).  Otherwise,
    configs whose BLOCK_M already reaches M are dropped as redundant.
    Configs that end up with identical (BLOCK_M, BLOCK_N, num_warps,
    num_stages) keys are deduplicated, keeping the first seen.

    Args:
        configs: iterable of candidate ``triton.Config`` objects.
        named_args: kernel arguments by name; only ``"M"`` is read.
        **kwargs: unused; accepted to match triton's prune-hook signature.

    Returns:
        List of pruned ``triton.Config`` objects.
    """
    M = named_args["M"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_warps,
            config.num_stages,
        )
        doopt = BLOCK_N * M * 4 * 3 < MAX_NRAM_SIZE
        if doopt:
            # Deep-copy before mutating so the shared config list passed by
            # the autotuner is not modified in place.
            config = copy.deepcopy(config)
            BLOCK_M = config.kwargs["BLOCK_M"] = M
            num_stages = config.num_stages = 1
        elif BLOCK_M >= M:
            # BLOCK_M already covers M but the buffers don't fit in NRAM;
            # this shape is redundant with the one-tile variant above.
            continue
        key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, config)
    return list(configs_map.values())

43 

44 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
    prune_configs_by={"early_config_prune": config_prune},
)
@triton.heuristics(
    values={
        # One tile suffices when the whole reduction dim fits in BLOCK_M
        # (config_prune sets BLOCK_M = M when the tile fits in NRAM).
        "ONE_TILE_PER_CTA": lambda args: args["M"] <= args["BLOCK_M"],
    },
)
@triton.jit
def mv_kernel(
    A,  # pointer to the matrix, logically (N, M)
    B,  # pointer to the length-M vector
    C,  # pointer to the length-N output
    N,
    M,
    stride_an,  # A stride along the output (n) dim
    stride_am,  # A stride along the reduction (m) dim
    stride_bm,  # B stride along the reduction dim
    stride_cn,  # C stride along the output dim
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Matrix-vector product: C[n] = sum_m A[n, m] * B[m].
    # Each program computes BLOCK_N output elements; accumulation is in fp32
    # regardless of input dtype (inputs are upcast via .to(tl.float32)).
    pid = tl.program_id(0)
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    if ONE_TILE_PER_CTA:
        # Entire reduction dim in one tile: B is loaded unmasked (the
        # heuristic guarantees M <= BLOCK_M; assumes BLOCK_M == M here so no
        # m-mask is needed — config_prune enforces that, TODO confirm for
        # untuned configs where BLOCK_M > M).
        a = tl.load(A_ptrs, mask=n_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs).to(tl.float32)
        acc = tl.sum(a * b, axis=1)
        C_ptrs = C + offset_n * stride_cn
        tl.store(C_ptrs, acc[:, None], mask=n_mask)
    else:
        # March over the reduction dim in BLOCK_M steps, masking the tail,
        # then reduce the (BLOCK_N, BLOCK_M) accumulator once at the end.
        acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
        for m in range(0, M, BLOCK_M):
            m_mask = m + offset_m < M
            a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
            b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
            acc += a * b
            A_ptrs += BLOCK_M * stride_am
            B_ptrs += BLOCK_M * stride_bm
        acc = tl.sum(acc, axis=1)
        C_ptrs = C + offset_n * stride_cn
        tl.store(C_ptrs, acc[:, None], mask=n_mask)

95 

96 

def mv(inp, vec):
    """Matrix-vector product ``inp @ vec`` on the current device.

    Args:
        inp: 2-D tensor of shape (N, M).
        vec: 1-D tensor of shape (M,).

    Returns:
        1-D tensor of shape (N,) with the dtype and device of ``inp``.

    Raises:
        AssertionError: if the inputs are not a matrix and a vector, or if
            their inner dimensions do not match.
    """
    logger.debug("GEMS_CAMBRICON MV")
    # Validate rank first so a 1-D `inp` fails with a clear message instead
    # of an IndexError on `inp.shape[1]`.
    assert inp.ndim == 2 and vec.ndim == 1, "mv expects a matrix and a vector"
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)

    # 1-D launch grid: one program per BLOCK_N output elements.
    def grid(META):
        return (triton.cdiv(N, META["BLOCK_N"]),)

    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp,
            vec,
            out,
            N,
            M,
            inp.stride(0),
            inp.stride(1),
            vec.stride(0),
            out.stride(0),
        )
    return out