Coverage for src/flag_gems/runtime/backend/_cambricon/ops/var_mean.py: 0%
160 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
11from ..utils import (
12 MAX_NRAM_SIZE,
13 TOTAL_CORE_NUM,
14 cfggen_reduce_op,
15 count_divisible_by_2,
16)
# Per-module child of the package-wide "flag_gems" logger.
_GEMS_LOGGER = logging.getLogger("flag_gems")
logger = _GEMS_LOGGER.getChild(__name__.lstrip("."))
@triton.jit
def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
    # Merge two Welford partials (mean, count, M2) into one. Used as the
    # tl.reduce combine_fn and for the pairwise merge in var_mean_kernel_2.
    # M_* is the sum of squared deviations from that partial's own mean.
    #
    # Parallel update of Chan, Golub & LeVeque:
    #   delta = mean_y - mean_x
    #   mean  = mean_x + delta * count_y / count
    #   M     = M_x + M_y + delta**2 * count_x * count_y / count
    # This is algebraically identical to the sum-of-squares form
    #   M_x + count_x*mean_x**2 + M_y + count_y*mean_y**2 - count*mean**2
    # but never subtracts two large, nearly equal quantities. That subtraction
    # throws away most of the float32 precision when |mean| is much larger
    # than the spread of the data.
    count = count_x + count_y
    # Clamp so merging two empty partials (count == 0) does not divide by
    # zero; the result is then (mean_x, 0, M_x + M_y).
    _count = tl.maximum(count, 1)
    delta = mean_y - mean_x
    frac_y = count_y / _count
    mean = mean_x + delta * frac_y
    M = M_x + M_y + delta * delta * count_x * frac_y
    return mean, count, M
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise variance and mean of X addressed as an (M, N) row-major matrix:
    # row r occupies X[r * N : (r + 1) * N]. Results go to Var[r] / Mean[r].
    #
    # Rows are processed in tiles of BLOCK_M. Program p handles tiles
    # p, p + num_programs, p + 2 * num_programs, ...
    # Inside a tile, every (row, lane) slot keeps a running Welford
    # mean / M2 / count over the columns it visits. The lanes of each row are
    # then merged with welford_func through tl.reduce.
    n_programs = tl.num_programs(0)
    n_tiles = tl.cdiv(M, BLOCK_M)
    rounds = tl.cdiv(n_tiles, n_programs)
    row_ids = tl.arange(0, BLOCK_M)[:, None]
    col_ids = tl.arange(0, BLOCK_N)[None, :]
    for r in range(rounds):
        tile = r * n_programs + tl.program_id(0)
        rows = tile * BLOCK_M + row_ids
        in_rows = rows < M
        row_base = X + rows * N

        run_mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        run_m2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        run_cnt = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for start in range(0, N, BLOCK_N):
            cols = start + col_ids
            valid = in_rows & (cols < N)
            val = tl.load(row_base + cols, valid, other=0.0).to(tl.float32)

            # Masked slots leave count and mean unchanged and add nothing to M2.
            new_cnt = run_cnt + valid
            new_mean = (run_mean * run_cnt + val) / tl.maximum(new_cnt, 1)
            run_m2 += (val - new_mean) * (val - run_mean) * valid
            run_mean = new_mean
            run_cnt = new_cnt

        row_mean, _, row_m2 = tl.reduce(
            (run_mean, run_cnt, run_m2), axis=1, combine_fn=welford_func
        )
        row_var = row_m2 / (N - correction)
        tl.store(Mean + rows, row_mean[:, None], in_rows)
        tl.store(Var + rows, row_var[:, None], in_rows)
def prune_varmean_config(configs, named_args, **kwargs):
    """Filter var_mean_kernel_1 autotune configs for the current ``M``.

    A config survives when ``M // BLOCK_SIZE`` is at least 1 and its
    estimated float32 working set, ``BLOCK_SIZE * 4 bytes * buffers``, is
    strictly below ``MAX_NRAM_SIZE``. When there are fewer full blocks than
    ``TOTAL_CORE_NUM``, only ``num_stages == 1`` configs are accepted and
    3 buffers are assumed; otherwise 6 buffers are assumed.

    If no config survives (e.g. ``M`` is smaller than every candidate
    ``BLOCK_SIZE``), a single ``BLOCK_SIZE=512`` fallback config is returned,
    so the result is never empty.
    """
    total = named_args["M"]

    def _fits(cfg):
        block = cfg.kwargs["BLOCK_SIZE"]
        full_blocks = total // block
        if full_blocks < 1:
            return False
        if full_blocks < TOTAL_CORE_NUM:
            # Fewer full blocks than cores: a core processes a single
            # BLOCK_SIZE of data, so only single-stage configs are kept.
            if cfg.num_stages > 1:
                return False
            buffers = 3
        else:
            buffers = 6
        # 4 bytes per element: float32 is assumed as the working type.
        return block * 4 * buffers < MAX_NRAM_SIZE

    kept = [cfg for cfg in configs if _fits(cfg)]
    if not kept:
        kept.append(triton.Config({"BLOCK_SIZE": 512}, num_warps=1, num_stages=1))
    return kept
@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    prune_configs_by={"early_config_prune": prune_varmean_config},
    key=["M"],
    reset_to_zero=["Acc", "Average", "Count"],
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    },
)
@triton.jit
def var_mean_kernel_1(
    X, Acc, Average, Count, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    # First stage of the full (all-elements) reduction over the flat input X
    # of M elements.
    #
    # Program pid computes one float32 partial over its share of X:
    #   Acc[pid]     = sum of squared deviations from the partial mean (M2)
    #   Average[pid] = partial mean
    #   Count[pid]   = number of elements it covered (as float)
    # var_mean (the host wrapper in this file) then merges the partials with
    # var_mean_kernel_2.
    #
    # X is addressed with flat offsets, so it must be dense in logical order.
    #
    # ONE_TILE_PER_CTA is set by the heuristic above when
    # M <= BLOCK_SIZE * TOTAL_CORE_NUM.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    count = 0.0
    average = 0.0
    acc = 0.0
    if ONE_TILE_PER_CTA:
        # Each program reduces the single tile starting at pid * BLOCK_SIZE.
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < M
        x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
        count = tl.sum(mask.to(tl.float32))
        # No zero guard: var_mean launches
        # min(cdiv(M, BLOCK_SIZE), TOTAL_CORE_NUM) programs, so every
        # program's tile starts below M and count >= 1.
        average = tl.sum(x) / count
        # NOTE(review): one-pass sum-of-squares form. In float32 it can lose
        # most significant digits when |mean| >> standard deviation; consider
        # a two-pass (x - average) formulation on the already-loaded tile.
        acc = tl.sum(x * x) - count * average * average
    else:
        # Each program visits tiles starting at pid * BLOCK_SIZE, advancing
        # by num_programs * BLOCK_SIZE each step. It keeps per-lane running
        # sums of x and x*x, plus a scalar count.
        _tmp1 = tl.zeros([BLOCK_SIZE], tl.float32)
        _tmp2 = tl.zeros([BLOCK_SIZE], tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for block_start_offset in range(block_start, M, step):
            offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offsets < M
            x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
            _count = tl.sum(mask.to(tl.float32))
            count = count + _count
            _tmp1 = _tmp1 + x
            _tmp2 = _tmp2 + x * x
        # Guards the division against an empty share of X.
        count = tl.maximum(count, 1)
        average = tl.sum(_tmp1) / count
        # NOTE(review): same one-pass sum-of-squares precision caveat as the
        # single-tile branch above.
        acc = tl.sum(_tmp2) - count * average * average

    # Each program writes its partial to slot pid of the three buffers.
    Acc = Acc + pid
    Average = Average + pid
    Count = Count + pid

    tl.store(Average, average)
    tl.store(Acc, acc)
    tl.store(Count, count)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("var_mean"))
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
    Acc,
    Average,
    Count,
    Var,
    Mean,
    M,
    correction,
    BLOCK_NUM: tl.constexpr,
    ITER_NUM: tl.constexpr,
):
    # Second stage of the full reduction: merge the BLOCK_NUM partials written
    # by var_mean_kernel_1 into a single mean and variance, computed in float32.
    #
    # Inputs:
    #   Acc, Average, Count: BLOCK_NUM partials each (M2, mean, element count).
    #   M: total number of reduced elements.
    #   correction: variance divisor is (M - correction).
    #   BLOCK_NUM: number of partials (the caller passes TOTAL_CORE_NUM).
    #   ITER_NUM: the caller passes count_divisible_by_2(BLOCK_NUM) + 1.
    #     Presumably that helper counts how many times BLOCK_NUM can be
    #     halved exactly (TODO confirm).
    # Outputs:
    #   The results are stored at Mean[0] and Var[0].
    offset = tl.arange(0, BLOCK_NUM)
    Acc = Acc + offset
    Average = Average + offset
    Count = Count + offset
    acc = tl.load(Acc)
    average = tl.load(Average)
    count = tl.load(Count)

    # Pairwise tree merge. Step x folds the slice
    # [BLOCK_NUM // 2**x, 2 * (BLOCK_NUM // 2**x)) into the prefix
    # [0, BLOCK_NUM // 2**x) with welford_func, halving the active prefix.
    # NOTE(review): slicing and slice-assignment on tensors appear to rely on
    # the Cambricon Triton backend; upstream Triton does not support them.
    for x in tl.static_range(1, ITER_NUM, 1):
        (
            average[: BLOCK_NUM // (2**x)],
            count[: BLOCK_NUM // (2**x)],
            acc[: BLOCK_NUM // (2**x)],
        ) = welford_func(
            average[: BLOCK_NUM // (2**x)],
            count[: BLOCK_NUM // (2**x)],
            acc[: BLOCK_NUM // (2**x)],
            average[BLOCK_NUM // (2**x) : (BLOCK_NUM // (2**x)) * 2],
            count[BLOCK_NUM // (2**x) : (BLOCK_NUM // (2**x)) * 2],
            acc[BLOCK_NUM // (2**x) : (BLOCK_NUM // (2**x)) * 2],
        )
    # Reduce the remaining prefix of length BLOCK_NUM // 2**(ITER_NUM - 1).
    mean, _, nvar = tl.reduce(
        (
            average[: BLOCK_NUM // (2 ** (ITER_NUM - 1))],
            count[: BLOCK_NUM // (2 ** (ITER_NUM - 1))],
            acc[: BLOCK_NUM // (2 ** (ITER_NUM - 1))],
        ),
        axis=0,
        combine_fn=welford_func,
    )

    # FIXME: Go back to the single tl.reduce below (original programming mode)
    # once tl.reduce has been optimized on this backend.
    # mean, _, nvar = tl.reduce((average, count, acc), axis=0, combine_fn=welford_func)

    var = nvar / (M - correction)
    tl.store(Mean, mean)
    tl.store(Var, var)
def _var_mean_all(x, correction):
    # Reduce over every element of x. Returns (var, mean) shaped
    # [1] * x.ndim in x.dtype.
    # var_mean_kernel_1 writes one (M2, mean, count) partial per program into
    # TOTAL_CORE_NUM slots; var_mean_kernel_2 merges those partials.
    #
    # var_mean_kernel_1 addresses X with flat offsets, so x must be dense in
    # logical order. Without this, a strided view such as x[:, ::2] would be
    # read from the wrong storage elements. contiguous() is a no-op when x is
    # already contiguous.
    x = x.contiguous()
    shape = [1] * x.ndim
    M = x.numel()

    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    var = torch.empty(shape, dtype=x.dtype, device=x.device)
    mean = torch.empty(shape, dtype=x.dtype, device=x.device)
    # Per-program partials. Slots of programs that are not launched stay zero
    # (count 0) and contribute nothing when merged.
    acc = torch.zeros([TOTAL_CORE_NUM], dtype=torch.float, device=x.device)
    average = torch.zeros([TOTAL_CORE_NUM], dtype=torch.float, device=x.device)
    count = torch.zeros([TOTAL_CORE_NUM], dtype=torch.float, device=x.device)
    loop_num = count_divisible_by_2(TOTAL_CORE_NUM) + 1

    with torch_device_fn.device(x.device):
        var_mean_kernel_1[grid](x, acc, average, count, M)
        var_mean_kernel_2[(1,)](
            acc,
            average,
            count,
            var,
            mean,
            M,
            correction,
            BLOCK_NUM=TOTAL_CORE_NUM,
            ITER_NUM=loop_num,
        )
    return var, mean


def _var_mean_dims(x, dim, correction):
    # Reduce x over the normalized (non-negative) dims in `dim`. Returns
    # (var, mean) in x.dtype with the reduced dims kept as size 1.
    shape = list(x.shape)
    # NOTE(review): dim_compress is assumed to move the reduced dims last and
    # return a contiguous tensor, because var_mean_welford_kernel addresses it
    # as an (M, N) row-major matrix. Confirm in flag_gems.utils.
    x = dim_compress(x, dim)
    N = 1
    for d in dim:
        N *= shape[d]
        shape[d] = 1
    # NOTE(review): N == 0 (an empty reduced dim) raises ZeroDivisionError here.
    M = x.numel() // N
    var = torch.empty(shape, dtype=x.dtype, device=x.device)
    mean = torch.empty(shape, dtype=x.dtype, device=x.device)

    grid = lambda META: (min(triton.cdiv(M, META["BLOCK_M"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(x.device):
        var_mean_welford_kernel[grid](x, var, mean, M, N, correction)
    return var, mean


def var_mean(x, dim=None, *, correction=None, keepdim=False):
    """Return ``(var, mean)`` of ``x`` over ``dim``, like ``torch.var_mean``.

    Args:
        x: input tensor.
        dim: dimension(s) to reduce. Accepts an int or a sequence of ints;
            negative values are allowed. ``None``, or a sequence covering
            every dimension, reduces the whole tensor.
        correction: the variance divisor is ``(count - correction)``.
            Defaults to 1.
        keepdim: if True, reduced dimensions are kept with size 1.

    Returns:
        A ``(var, mean)`` tuple of tensors with ``x.dtype``.
    """
    logger.debug("GEMS_CAMBRICON VAR MEAN")
    if correction is None:
        correction = 1.0
    if isinstance(dim, int):
        # aten's ``int[1]?`` schema allows a bare int; normalize it so len()
        # and iteration below work.
        dim = [dim]

    # A 0-dim tensor has exactly one element, so any valid dim reduces
    # everything. Treating it as a full reduction also avoids `d % 0` below.
    if dim is None or x.ndim == 0 or len(dim) == x.ndim:
        dim = list(range(x.ndim))
        var, mean = _var_mean_all(x, correction)
    else:
        dim = [d % x.ndim for d in dim]
        var, mean = _var_mean_dims(x, dim, correction)

    if not keepdim:
        var = var.squeeze(dim=dim)
        mean = mean.squeeze(dim=dim)
    return var, mean