Coverage for src/flag_gems/runtime/backend/_ascend/ops/mean.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module-level logger. Its name is the fixed prefix
# "flag_gems.runtime._ascend.ops." followed by the last dotted component of
# this module's __name__.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

14 

15 

@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Adds this program's share of sum(inp[0:M]) / M into the single element
    # that `out` points to. `out` must be zero before launch, because every
    # program contributes through an atomic add.
    #
    # inp        -- pointer to the M input elements, read at linear offsets
    # out        -- pointer to the accumulator element
    # M          -- number of input elements
    # BLOCK_SIZE -- number of elements handled per loop iteration
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    # Widen the loop start so the offsets derived from it are 64-bit.
    block_start = block_start.to(tl.int64)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        # Cast explicitly so the loop-carried accumulator stays float32 for
        # every input dtype (the same approach mean_dim_kernel uses).
        inp_val = tl.load(inp + offset, mask=mask, other=0.0).to(tl.float32)
        acc = inp_val + acc

    # Dividing each partial sum by M before the atomic add means the
    # accumulated total is the mean itself.
    mean_val = tl.sum(acc, axis=0) / M
    tl.atomic_add(out, mean_val)

38 

39 

def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-dim tensor.

    Args:
        inp: Input tensor. It is made contiguous before launch because the
            kernel reads its elements at linear offsets.
        dtype: Dtype of the returned tensor. Defaults to ``inp.dtype``.

    Returns:
        A 0-dim tensor of ``dtype`` on ``inp.device``. For an empty input the
        result is NaN, since the mean of zero elements is 0/0.
    """
    logger.debug("GEMS MEAN")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype

    if M == 0:
        # block_size would be 0 below and triton.cdiv(M, 0) would raise
        # ZeroDivisionError, so answer directly without launching anything.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)

    # The kernel indexes the buffer linearly; a strided view would otherwise
    # be read incorrectly. This is a no-op for contiguous tensors.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    # Accumulate in float32: the kernel adds float32 partial means atomically,
    # so a low-precision target would round on every add. Cast once at the end.
    acc = torch.zeros([], dtype=torch.float32, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(triton.cdiv(M, block_size), 1, 1)](inp, acc, M, block_size)
    return acc.to(dtype)

52 

53 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mean"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise mean: each program handles BLOCK_M rows. Element (r, c) of the
    # input is read from X + r * N + c, and the mean of row r is written to
    # Mean + r.
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ptr = X + rows * N
    out_ptr = Mean + rows
    rows_valid = rows < M

    # Sum each row in float32, walking the N columns BLOCK_N at a time.
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)[None, :]
        cols_valid = cols < N
        tile_mask = rows_valid and cols_valid

        # Masked-out lanes load 0.0 so they do not affect the sum.
        tile = tl.load(row_ptr + cols, tile_mask, other=0.0).to(tl.float32)
        acc += tile

    # Collapse the column axis, divide by the row length, and restore the
    # trailing axis so the value lines up with the [BLOCK_M, 1] pointers.
    row_mean = tl.sum(acc, axis=1) / N
    row_mean = row_mean[:, None]
    tl.store(out_ptr, row_mean, rows_valid)

79 

80 

def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Return the mean of ``x`` over the dimensions listed in ``dim``.

    Args:
        x: Input tensor.
        dim: ``None`` to reduce over every element, a single int, or a
            sequence of ints. Negative indices are wrapped modulo ``x.ndim``.
        keepdim: If true, reduced dimensions are kept with size 1; otherwise
            they are removed from the result.
        dtype: Dtype of the returned tensor. Defaults to ``x.dtype``.

    Returns:
        A tensor of ``dtype`` on ``x.device`` holding the means.
    """
    logger.debug("GEMS MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    if dim is None:
        out = mean(x, dtype=dtype)
        # The full reduction yields a 0-dim tensor; only keepdim=True asks for
        # one size-1 axis per input dimension.
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    # Accept a bare int as well as a sequence of ints.
    if isinstance(dim, int):
        dim = [dim]

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if x.numel() == 0:
        # Either N == 0 (x.numel() // N would raise ZeroDivisionError) or there
        # are no rows, which would give an empty launch grid. Every output slot
        # that exists is the mean of zero elements, i.e. NaN.
        out = torch.full(shape, float("nan"), dtype=dtype, device=x.device)
    else:
        # The kernel addresses its input as M rows of N elements each.
        # Presumably dim_compress arranges the reduced dims last and
        # contiguous -- confirm against flag_gems.utils.
        x = dim_compress(x, dim)
        M = x.numel() // N
        out = torch.empty(shape, dtype=dtype, device=x.device)
        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

        with torch_device_fn.device(x.device):
            mean_dim_kernel[grid](x, out, M, N)

    if not keepdim:
        out = out.squeeze(dim)
    return out

107 return out