Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mean.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry, libtuner 

10 

11from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

@libentry()
@libtuner(
    configs=cfggen_reduce_op(), key=["M"], strategy=["log"], reset_to_zero=["out"]
)
@triton.jit
def mean_kernel_1(
    inp,  # pointer to the flattened input tensor
    out,  # pointer to a single float32 accumulator; caller must zero it first
    M,  # total number of elements in `inp`
    BLOCK_SIZE: tl.constexpr,  # elements handled per program per iteration
):
    # Global mean via a grid-stride loop: each program starts at
    # pid * BLOCK_SIZE and advances by num_programs * BLOCK_SIZE, keeping a
    # per-lane float32 partial sum. It then folds its partial mean (sum / M)
    # into `out` with an atomic add, so `out` ends up holding the full mean.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    _tmp = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    # Widen the starting offset to int64 — presumably so pointer arithmetic
    # does not overflow for very large M (TODO confirm).
    block_start = block_start.to(tl.int64)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        # Out-of-range lanes load 0.0, so they do not perturb the sum.
        inp_val = tl.load(inp + offset, mask=mask, other=0.0)
        _tmp = inp_val + _tmp

    # Reduce the per-lane partials and contribute this program's share.
    mean_val = tl.sum(_tmp, axis=0) / M
    tl.atomic_add(out, mean_val)

41 

42 

def mean(inp, *, dtype=None):
    """Compute the mean over all elements of `inp`.

    Args:
        inp: input tensor (any shape; fully reduced).
        dtype: optional result dtype; defaults to ``inp.dtype``.

    Returns:
        A 0-d tensor holding the mean, cast to ``dtype``.
    """
    logger.debug("GEMS_CAMBRICON MEAN")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if M == 0:
        # Match torch.mean on an empty tensor (NaN) and avoid launching the
        # kernel with a zero-sized grid (cdiv(0, BLOCK) == 0 programs would
        # otherwise leave `out` at 0.0, which is wrong).
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)
    # One program per BLOCK_SIZE chunk, capped at the core count; each
    # program atomically adds its partial mean into `out`.
    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    # Accumulate in float32 regardless of the input dtype for accuracy.
    out = torch.zeros([], dtype=torch.float32, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[grid](inp, out, M)
    return out.to(dtype)

54 

55 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mean"),
    key=["M", "N"],
    strategy=["log", "log"],
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise mean over a contiguous (M, N) view: X holds M rows of N
    # elements each, and Mean receives one value per row.
    # Map the program id to the row of X it should compute.
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    iter_num = tl.cdiv(task_num, num_prog)
    # Row tiles are distributed round-robin: on outer iteration i, this
    # program handles tile (i * num_prog + program_id).
    for i in range(0, iter_num):
        # Column vector [BLOCK_M, 1] of absolute row indices for this tile.
        pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
            :, None
        ]
        X_ptr = X + pid * N  # first element of each row
        Mean_ptr = Mean + pid
        row_mask = pid < M  # guards the tail tile when M % BLOCK_M != 0

        # Compute mean: accumulate each row in float32, BLOCK_N columns at
        # a time, then divide by N and collapse the column lanes.
        _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            # Masked-off positions load 0.0 and do not affect the sum.
            a = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
            _mean += a
        _mean /= N
        # Sum the BLOCK_N lanes into one mean per row, kept as [BLOCK_M, 1].
        mean = tl.sum(_mean, axis=1)[:, None]
        tl.store(Mean_ptr, mean, row_mask)

88 

89 

def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Compute the mean of `x` over the dimensions listed in `dim`.

    Args:
        x: input tensor.
        dim: iterable of dimensions to reduce (negative values allowed),
            or None for a full reduction.
        keepdim: keep the reduced dimensions with size 1.
        dtype: optional result dtype; defaults to ``x.dtype``.

    Returns:
        Tensor of means, with reduced dims removed (or kept as size 1 when
        ``keepdim`` is True).
    """
    logger.debug("GEMS_CAMBRICON MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    if dim is None:
        # Full reduction: `mean` returns a 0-d tensor. torch.mean keeps
        # every dim as size 1 only when keepdim=True, so reshape in that
        # case. (The previous code reshaped when keepdim was False, which
        # inverted torch's semantics.)
        out = mean(x, dtype=dtype)
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]  # normalize negative dims
    # Move the reduced dims to the innermost positions so every output row
    # reads a contiguous run of N input elements.
    x = dim_compress(x, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1  # keepdim-style output shape; squeezed below if needed
    M = x.numel() // N
    out = torch.empty(shape, dtype=dtype, device=x.device)
    grid = lambda META: (min(triton.cdiv(M, META["BLOCK_M"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(x.device):
        mean_dim_kernel[grid](x, out, M, N)
    if not keepdim:
        out = out.squeeze(dim)
    return out