Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/var_mean.py: 0%
120 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_block_m(args):
    """Heuristic BLOCK_M: rows per program when M rows are split ~12 ways, rounded up to a power of 2."""
    rows_per_program = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_program)
def heur_block_n(args):
    """Heuristic BLOCK_N: the reduction width N rounded up to a power of 2, capped at 8192."""
    n_pow2 = triton.next_power_of_2(args["N"])
    return n_pow2 if n_pow2 < 8192 else 8192
@triton.jit
def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
    """Merge two Welford partial states (mean, count, M2) into one.

    M2 is the sum of squared deviations from the mean; the merge rebuilds the
    combined M2 from sums of squares, using the identity
    M_i + count_i * mean_i**2 == sum(x_i**2).
    """
    total = count_x + count_y
    # Clamp the divisor so an all-empty merge yields 0/1 instead of 0/0.
    safe_total = tl.maximum(total, 1)
    sum_x = mean_x * count_x
    sum_y = mean_y * count_y
    merged_mean = (sum_x + sum_y) / safe_total
    merged_M = (
        M_x + sum_x * mean_x + M_y + sum_y * mean_y - total * merged_mean * merged_mean
    )
    return merged_mean, total, merged_M
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise var/mean over an (M, N)-shaped, row-contiguous X: each program
    # handles BLOCK_M rows, sweeping the N columns in BLOCK_N-wide tiles and
    # accumulating a per-lane Welford state that is merged across lanes at the
    # end. Results are one mean and one variance per row.
    # Map the program id to the row of X it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Var = Var + pid
    Mean = Mean + pid
    row_mask = pid < M

    # Per-lane Welford accumulators: running mean, sum of squared deviations,
    # and number of elements folded into each lane so far.
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)

        # mask adds 1 only for in-bounds lanes, so padded lanes keep count 0.
        count = _count + mask
        # Clamp to 1 to avoid 0/0 on lanes that have seen no elements yet.
        cnt = tl.maximum(count, 1)
        cur_mean = (_mean * _count + x) / cnt
        # Welford update: delta_new * delta_old; multiplied by mask so padded
        # lanes contribute nothing.
        _acc += (x - cur_mean) * (x - _mean) * mask
        _mean = cur_mean
        _count = count

    # Merge the BLOCK_N per-lane partial states of each row into one state.
    mean, _, acc = tl.reduce((_mean, _count, _acc), axis=1, combine_fn=welford_func)
    # Divisor N - correction (correction is 1 for Bessel's correction upstream).
    var = acc / (N - correction)
    mean = mean[:, None]
    var = var[:, None]
    # Write mean / var
    tl.store(Mean, mean, row_mask)
    tl.store(Var, var, row_mask)
@libentry()
@triton.jit
def var_mean_kernel_1(
    X,
    Acc,
    Average,
    Count,
    N,
    BLOCK_N: tl.constexpr,
):
    """Stage 1 of the full reduction: each program reduces one BLOCK_N chunk
    of the flattened X and stores its partial statistics — chunk mean, chunk
    sum of squared deviations, and the number of valid elements — at slot
    ``program_id``."""
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = offsets < N

    x = tl.load(X + offsets, in_bounds, other=0.0).to(tl.float32)

    # The last chunk may be ragged, so count only in-bounds lanes.
    n_valid = tl.sum(in_bounds.to(tl.float32))
    chunk_mean = tl.sum(x) / n_valid
    # sum((x - mean)^2) via the sum-of-squares identity.
    chunk_acc = tl.sum(x * x) - n_valid * chunk_mean * chunk_mean

    tl.store(Average + block_id, chunk_mean)
    tl.store(Acc + block_id, chunk_acc)
    tl.store(Count + block_id, n_valid)
def heur_block_n(args):
    # NOTE(review): this redefinition shadows the earlier heur_block_n. The
    # earlier var_mean_welford_kernel already captured the first definition in
    # its @triton.heuristics dict at decoration time, so only var_mean_kernel_2
    # below binds this one — but the duplicate name is fragile; consider
    # renaming one of them.
    # BLOCK_N must cover all BLOCK_NUM stage-1 partials in a single block.
    return triton.next_power_of_2(args["BLOCK_NUM"])
@libentry()
# @triton.heuristics(runtime.get_heuristic_config("var_mean"))
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
    Acc,
    Average,
    Count,
    Var,
    Mean,
    N,
    correction,
    BLOCK_NUM,
    BLOCK_N: tl.constexpr,
):
    """Stage 2 of the full reduction: merge the BLOCK_NUM per-chunk partials
    produced by var_mean_kernel_1 with a Welford reduction and store the final
    scalar mean and variance (single-program launch)."""
    lanes = tl.arange(0, BLOCK_N)
    valid = lanes < BLOCK_NUM
    partial_acc = tl.load(Acc + lanes, valid, other=0.0).to(tl.float32)
    partial_mean = tl.load(Average + lanes, valid, other=0.0).to(tl.float32)
    partial_count = tl.load(Count + lanes, valid, other=0.0).to(tl.float32)

    mean, _, m2 = tl.reduce(
        (partial_mean, partial_count, partial_acc), axis=0, combine_fn=welford_func
    )
    # Divisor N - correction (correction defaults to 1 upstream).
    var = m2 / (N - correction)
    tl.store(Mean, mean)
    tl.store(Var, var)
def var_mean(x, dim=None, *, correction=None, keepdim=False):
    """Return ``(var, mean)`` of ``x`` reduced over ``dim``.

    Args:
        x: input tensor.
        dim: ``None`` (reduce all elements), a single int, or an iterable of
            ints; negative dims are accepted.
        correction: divisor is ``N - correction``; defaults to 1 (Bessel's
            correction), matching ``torch.var_mean``.
        keepdim: if True, retain the reduced dimensions with size 1.

    Returns:
        Tuple ``(var, mean)`` with the dtype and device of ``x``.
    """
    logger.debug("GEMS VAR MEAN")
    if correction is None:
        correction = 1.0
    # torch.var_mean also accepts a bare int for dim; normalize it so the
    # len() check and iteration below work (the original crashed on int dim).
    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == x.ndim:
        # Full reduction over the flattened tensor: two-stage split reduction
        # (per-chunk partials, then a single-program Welford merge).
        dim = list(range(x.ndim))
        shape = [1] * x.ndim
        N = x.numel()
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)
        BLOCK_N = 1024
        BLOCK_NUM = triton.cdiv(N, BLOCK_N)
        # Per-chunk partials written by stage 1, consumed by stage 2.
        acc = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)
        average = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)
        count = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)

        with torch_device_fn.device(x.device):
            var_mean_kernel_1[(BLOCK_NUM,)](x, acc, average, count, N, BLOCK_N=BLOCK_N)
            var_mean_kernel_2[(1,)](
                acc,
                average,
                count,
                var,
                mean,
                N,
                correction,
                BLOCK_NUM,
                isCloseUnrollControl=True,
            )
    else:
        # Partial reduction: move the reduced dims to the end (dim_compress)
        # and run the row-wise Welford kernel over the resulting (M, N) view.
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]
        x = dim_compress(x, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = x.numel() // N
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)

        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
        with torch_device_fn.device(x.device):
            var_mean_welford_kernel[grid](
                x, var, mean, M, N, correction, isCloseUnrollControl=True
            )

    if not keepdim:
        var = var.squeeze(dim=dim)
        mean = mean.squeeze(dim=dim)
    return var, mean