Coverage for src/flag_gems/runtime/backend/_ascend/ops/mean.py: 0%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
# Module logger named after this op file (e.g. "flag_gems.runtime._ascend.ops.mean").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Accumulate the mean of the M elements of `inp` into the scalar `out`.

    Each program sums a strided subset of the flat input in float32, divides
    its partial sum by M, and atomically adds that partial mean into `out`.
    `out` must therefore be zero-initialized by the caller; the final value
    is the full mean once all programs have finished.
    """
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Per-lane float32 accumulator for numerical stability regardless of
    # the input dtype.
    _tmp = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    # int64 start offset — presumably to avoid 32-bit overflow of pointer
    # offsets on very large inputs (TODO confirm this is required on Ascend).
    block_start = block_start.to(tl.int64)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        # Out-of-range lanes load 0.0 so they do not perturb the sum.
        inp_val = tl.load(inp + offset, mask=mask, other=0.0)
        _tmp = inp_val + _tmp
    mean_val = tl.sum(_tmp, axis=0) / M
    # Each program contributes partial_sum / M; atomic since programs race.
    tl.atomic_add(out, mean_val)
def mean(inp, *, dtype=None):
    """Compute the mean of all elements of `inp`.

    Args:
        inp: input tensor; reduced over all of its elements.
        dtype: optional output dtype; defaults to ``inp.dtype``.

    Returns:
        A 0-dim tensor on ``inp.device`` holding the mean.
    """
    logger.debug("GEMS MEAN")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    # ~sqrt(M)-sized blocks balance grid size against per-program work.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    # Must start at zero: mean_kernel_1 accumulates partial means into it
    # via atomic_add.
    out = torch.zeros([], dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        mean_kernel_1[(triton.cdiv(M, block_size), 1, 1)](inp, out, M, block_size)
    return out
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mean"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise mean over an (M, N) view of X, one result per row into Mean.

    Assumes each of the M rows holds its N reduced elements contiguously
    (the caller arranges this via dim_compress) — TODO confirm layout.
    """
    # Map the program id to the BLOCK_M rows of X it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Mean = Mean + pid
    row_mask = pid < M
    # Accumulate row sums in float32, BLOCK_N columns at a time.
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        # Masked-out positions load 0.0 and so do not affect the sum.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=1) / N
    mean = mean[:, None]
    tl.store(Mean, mean, row_mask)
def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Compute the mean of `x` over the dimension(s) in `dim`.

    Args:
        x: input tensor.
        dim: dimension(s) to reduce — an int, a sequence of ints, or None
            (None reduces over all elements).
        keepdim: when True, keep each reduced dimension with size 1.
        dtype: optional output dtype; defaults to ``x.dtype``.

    Returns:
        Tensor with the reduced dimensions removed, or kept as size 1 when
        ``keepdim`` is True.
    """
    logger.debug("GEMS MEAN DIM")
    if dtype is None:
        dtype = x.dtype
    if dim is None:
        # Full reduction yields a 0-dim scalar; expand it back to an
        # all-ones shape only when the caller asked to keep dimensions
        # (matches torch.mean keepdim semantics — the original inverted
        # this condition).
        out = mean(x, dtype=dtype)
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    if isinstance(dim, int):
        # Accept a bare int like torch.mean does.
        dim = [dim]
    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]  # normalize negative dims
    # Move the reduced dims innermost so each output element reads N
    # consecutive inputs.
    x = dim_compress(x, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = x.numel() // N
    out = torch.empty(shape, dtype=dtype, device=x.device)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
    with torch_device_fn.device(x.device):
        mean_dim_kernel[grid](x, out, M, N)
    if not keepdim:
        out = out.squeeze(dim)
    return out