Coverage for src/flag_gems/ops/var_mean.py: 49%
115 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@triton.jit
def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
    """Merge two Welford partial aggregates into one.

    Each aggregate is a (mean, count, M) triple where M is the sum of
    squared deviations from the mean (M2).  Used as the ``combine_fn``
    for ``tl.reduce`` in the kernels below.
    """
    total = count_x + count_y
    # Clamp the divisor so fully-masked (empty) partitions do not divide by zero.
    safe_total = tl.maximum(total, 1)
    sum_x = mean_x * count_x
    sum_y = mean_y * count_y
    merged_mean = (sum_x + sum_y) / safe_total
    # M2 merge via the sum-of-squares identity: sum(x^2) - n * mean^2.
    merged_M = (
        M_x + sum_x * mean_x + M_y + sum_y * mean_y - total * merged_mean * merged_mean
    )
    return merged_mean, total, merged_M
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
    X,        # input, viewed as (M, N); reduction runs along the last axis
    Var,      # output variance, one value per row (M,)
    Mean,     # output mean, one value per row (M,)
    M,        # number of rows (kept dimensions, flattened)
    N,        # reduction length per row
    correction,  # Bessel correction: divisor is N - correction
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise var/mean via a streaming Welford update plus a final
    tl.reduce combine.  Each program handles BLOCK_M rows, scanning
    the N columns in BLOCK_N-wide tiles, keeping per-lane running
    (mean, count, M2) statistics in float32."""
    # Map the program id to the row of X it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Var = Var + pid
    Mean = Mean + pid
    row_mask = pid < M

    # Per-lane running statistics; accumulated in fp32 for accuracy.
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)

        # Welford single-sample update, gated by the load mask so padded
        # lanes contribute nothing (count stays put, acc term is zeroed).
        count = _count + mask
        cnt = tl.maximum(count, 1)  # avoid 0-division on never-active lanes
        cur_mean = (_mean * _count + x) / cnt
        _acc += (x - cur_mean) * (x - _mean) * mask
        _mean = cur_mean
        _count = count

    # Combine the BLOCK_N per-lane aggregates into one result per row.
    mean, _, acc = tl.reduce((_mean, _count, _acc), axis=1, combine_fn=welford_func)
    var = acc / (N - correction)
    mean = mean[:, None]
    var = var[:, None]
    # Write mean / var
    tl.store(Mean, mean, row_mask)
    tl.store(Var, var, row_mask)
@libentry()
@triton.jit
def var_mean_kernel_1(
    X,
    Acc,
    Average,
    Count,
    N,
    BLOCK_N: tl.constexpr,
):
    """First pass of the full-tensor reduction: every program computes
    partial statistics (sum of squared deviations, mean, element count)
    for one BLOCK_N-wide slice of the flat input and writes them to
    per-block slots in Acc / Average / Count."""
    block_id = tle.program_id(0)
    cols = block_id * BLOCK_N + tl.arange(0, BLOCK_N)

    in_bounds = cols < N

    # Out-of-range lanes load 0.0 and are excluded via the count below.
    x = tl.load(X + cols, in_bounds, other=0.0).to(tl.float32)

    n_valid = tl.sum(in_bounds.to(tl.float32))
    block_mean = tl.sum(x) / n_valid
    # Sum of squared deviations via sum(x^2) - n * mean^2.
    block_acc = tl.sum(x * x) - n_valid * block_mean * block_mean

    tl.store(Average + block_id, block_mean)
    tl.store(Acc + block_id, block_acc)
    tl.store(Count + block_id, n_valid)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("var_mean"))
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
    Acc,        # per-block sum of squared deviations from pass 1
    Average,    # per-block means from pass 1
    Count,      # per-block element counts from pass 1
    Var,        # scalar output: variance
    Mean,       # scalar output: mean
    N,          # total element count; divisor is N - correction
    correction, # Bessel correction term
    BLOCK_NUM,  # number of valid partial results
    BLOCK_N: tl.constexpr,  # power-of-two >= BLOCK_NUM (from heuristics)
):
    """Second pass: merge the per-block partial statistics from
    var_mean_kernel_1 into the final scalar var/mean, launched as a
    single program and combined with welford_func via tl.reduce."""
    offset = tl.arange(0, BLOCK_N)
    mask = offset < BLOCK_NUM
    Acc = Acc + offset
    Average = Average + offset
    Count = Count + offset
    # Masked lanes read 0.0, which welford_func treats as an empty
    # partition (count 0) and merges away harmlessly.
    acc = tl.load(Acc, mask, other=0.0).to(tl.float32)
    average = tl.load(Average, mask, other=0.0).to(tl.float32)
    count = tl.load(Count, mask, other=0.0).to(tl.float32)

    mean, _, nvar = tl.reduce((average, count, acc), axis=0, combine_fn=welford_func)

    var = nvar / (N - correction)
    tl.store(Mean, mean)
    tl.store(Var, var)
def var_mean(x, dim=None, *, correction=None, keepdim=False):
    """Compute the variance and mean of ``x``, mirroring ``torch.var_mean``.

    Args:
        x: input tensor.
        dim: dimension(s) to reduce over. ``None`` (or all dims) reduces
            the whole tensor. Accepts an int or a sequence of ints.
        correction: Bessel correction; the variance divisor is
            ``N - correction``. Defaults to 1 (sample variance).
        keepdim: keep reduced dimensions with size 1 in the outputs.

    Returns:
        Tuple ``(var, mean)`` with the same dtype/device as ``x``.
    """
    logger.debug("GEMS VAR MEAN")
    if correction is None:
        correction = 1.0

    # torch.var_mean also accepts a bare int dim; normalize to a list so
    # len()/iteration below work uniformly.
    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == x.ndim:
        # Full reduction: split the flat tensor into blocks (pass 1),
        # then combine the per-block partials into scalars (pass 2).
        dim = list(range(x.ndim))
        shape = [1] * x.ndim
        N = x.numel()
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)
        BLOCK_N = 1024
        BLOCK_NUM = triton.cdiv(N, BLOCK_N)
        # Scratch buffers holding one partial result per block.
        acc = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)
        average = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)
        count = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)

        with torch_device_fn.device(x.device):
            var_mean_kernel_1[(BLOCK_NUM,)](x, acc, average, count, N, BLOCK_N=BLOCK_N)
            var_mean_kernel_2[(1,)](
                acc, average, count, var, mean, N, correction, BLOCK_NUM
            )
    else:
        # Partial reduction: move the reduced dims to be contiguous and
        # innermost, then run the row-wise Welford kernel over (M, N).
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]  # normalize negative dims
        x = dim_compress(x, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = x.numel() // N
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)

        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
        with torch_device_fn.device(x.device):
            var_mean_welford_kernel[grid](x, var, mean, M, N, correction)

    if not keepdim:
        var = var.squeeze(dim=dim)
        mean = mean.squeeze(dim=dim)
    return var, mean