Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mean.py: 0%
92 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import builtins
2import logging
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
13from ..utils.block_size_utils import get_block_size_1d
# Module-level logger nested under the package logger.
# NOTE(review): lstrip(".") strips *leading* dots only — presumably meant to
# normalize a relative __name__; confirm this is the intended behavior.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def mean_kernel_1(
    inp,            # *flattened* input pointer
    mid,            # output pointer: one partial sum per program
    M,              # total number of input elements
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of a two-stage mean reduction: each program loads one
    # BLOCK_SIZE-wide slice of the flattened input, sums it, and writes
    # the partial sum to mid[pid].  The division by M happens in stage 2
    # (or on the host when only one block is launched).
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    # Mask the tail block: out-of-range lanes load 0.0 and therefore do
    # not contribute to the partial sum.
    mask = offset < M
    inp_val = tl.load(inp_ptrs, mask=mask, other=0.0)
    sum_val = tl.sum(inp_val, axis=0)
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)
@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2: a single program reduces all MID_SIZE partial sums produced
    # by mean_kernel_1 and divides by the total element count M, storing
    # the scalar mean to *out*.  BLOCK_MID is MID_SIZE rounded up to a
    # power of two (set by the host wrapper), so one tl.arange covers it.
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    mid_val = tl.load(mid_ptrs, mask=mask, other=0.0)
    sum_val = tl.sum(mid_val, axis=0) / M
    tl.store(out, sum_val)
def mean(inp, *, dtype=None):
    """Mean of all elements of ``inp`` via a two-stage Triton reduction.

    Stage 1 (``mean_kernel_1``) writes per-block partial sums into a
    scratch tensor; stage 2 (``mean_kernel_2``) reduces that tensor and
    divides by the element count.  When a single block suffices, stage 2
    is skipped and the division is done on the host.

    Returns a 0-d tensor of ``dtype`` (defaults to ``inp.dtype``) on
    ``inp.device``.
    """
    logger.debug("GEMS MEAN")
    total = inp.numel()
    dtype = inp.dtype if dtype is None else dtype

    blk = get_block_size_1d(total, inp.element_size())
    n_blocks = triton.cdiv(total, blk)
    reduce_width = triton.next_power_of_2(n_blocks)

    partial = torch.empty((n_blocks,), dtype=dtype, device=inp.device)
    result = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(n_blocks, 1, 1)](
            inp, partial, total, blk, buffer_size_limit=2048
        )
        # One partial sum already covers the whole input — skip stage 2.
        if n_blocks == 1:
            return (partial / total).reshape([])
        mean_kernel_2[(1, 1, 1)](
            partial, result, total, n_blocks, reduce_width, buffer_size_limit=2048
        )
    return result
def heur_m_block_size(args):
    # One tile of rows per cluster: ceil(M / 12) rounded up to a power of
    # two (12 == cluster_num on this backend).
    rows_per_cluster = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_cluster)
def heur_n_block_size(args):
    # Use the full row width, capped at 8192 columns per tile.
    n = args["N"]
    return n if n < 8192 else 8192
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("mean"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise mean over an (M, N) contiguous view: each program handles
    # BLOCK_M rows and walks the N columns in tiles of BLOCK_N,
    # accumulating in float32 before the final divide.
    # Map the program id to the row of X it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Mean = Mean + pid
    row_mask = pid < M

    # Compute mean
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): Python `and` between two tl tensors — assumes
        # Triton lowers this to an elementwise logical-and (broadcasting
        # row_mask against col_mask); confirm on this backend's Triton
        # version, otherwise `&` would be the explicit form.
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=1) / N
    mean = mean[:, None]
    tl.store(Mean, mean, row_mask)
def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Mean of ``x`` over the dimensions listed in ``dim``.

    Mirrors ``torch.mean(x, dim, keepdim, dtype=dtype)``: reduced
    dimensions are kept as size-1 axes when ``keepdim`` is True and
    removed otherwise.  ``dim=None`` reduces over all dimensions.

    Args:
        x: input tensor.
        dim: iterable of dimensions to reduce (negative indices allowed),
            or ``None`` for a full reduction.
        keepdim: retain each reduced dimension with size 1.
        dtype: output dtype; defaults to ``x.dtype``.
    """
    logger.debug("GEMS MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    if dim is None:
        out = mean(x, dtype=dtype)
        # BUGFIX: the reshape belongs to keepdim=True.  Per torch
        # semantics keepdim=True keeps every reduced dim as size 1 while
        # keepdim=False yields a 0-d tensor; the original condition was
        # inverted (`if not keepdim:`).
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    # Move the reduced dims innermost so each output row reduces over N
    # contiguous elements.
    x = dim_compress(x, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = x.numel() // N
    out = torch.empty(shape, dtype=dtype, device=x.device)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(x.device):
        mean_dim_kernel[grid](x, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out = out.squeeze(dim)
    return out