Coverage for src/flag_gems/ops/mean.py: 17%
193 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
2import math
3from functools import reduce
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import dim_compress, libentry, libtuner
12from flag_gems.utils import triton_lang_extension as tle
14logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of a two-stage mean: each program sums one BLOCK_SIZE
    slice of the flattened input and writes the partial sum to mid[pid].
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = inp.dtype.element_ty

    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    values = tl.load(inp + offsets, mask=in_bounds, other=0).to(acc_ty)
    tl.store(mid + block_id, tl.sum(values))
@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second stage of a two-stage mean: reduce the MID_SIZE partial sums
    in `mid`, divide by the total element count M, and store the scalar mean.
    """
    # Accumulate half-precision partials in float32 to limit rounding error.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = mid.dtype.element_ty

    idx = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid + idx, mask=idx < MID_SIZE, other=0).to(acc_ty)
    total = tl.sum(partials)
    # Dividing by M (total element count) turns the sum into the mean.
    tl.store(out, total / M)
def mean(inp, *, dtype=None):
    """Mean over all elements of `inp` via a two-stage Triton reduction.

    Stage 1 sums sqrt(M)-sized blocks into a `mid` buffer; stage 2 reduces
    `mid` and divides by M.

    Args:
        inp: input tensor (made contiguous internally).
        dtype: output dtype; defaults to ``inp.dtype``.

    Returns:
        0-d tensor holding the mean; NaN for an empty input (matching
        ``torch.mean``).
    """
    logger.debug("GEMS MEAN")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if M == 0:
        # Guard: with M == 0, next_power_of_2(ceil(sqrt(0))) is 0 and
        # triton.cdiv(0, 0) raises ZeroDivisionError. torch.mean of an
        # empty tensor is NaN, so return that directly.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)
    # Block size ~ sqrt(M) balances work between the two reduction stages.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid)
    return out
@libentry()
@triton.heuristics(runtime.get_heuristic_config("mean_non_inner"))
@triton.jit
def mean_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Mean over the middle axis of an (M, N, K)-viewed input (non-inner
    reduction). Program (pid_m, pid_k) reduces the N axis for one TILE_K
    stripe of columns and writes the result to output[(pid_m, k-stripe)].
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]

    if ONE_TILE_PER_CTA:
        # N fits in a single tile: one masked load, one reduction.
        rows = tl.arange(0, TILE_N)[:, None]
        valid = (rows < N) & (k_offsets < K)
        tile = tl.load(
            input_ptr + pid_m * N * K + rows * K + k_offsets,
            mask=valid,
            other=0,
        ).to(acc_ty)
        col_sums = tl.sum(tile, axis=0, keep_dims=True)
    else:
        # N is larger than one tile: walk it in TILE_N steps, accumulating.
        acc = tl.zeros([TILE_N, TILE_K], dtype=acc_ty)
        for row_start in range(0, N, TILE_N):
            rows = row_start + tl.arange(0, TILE_N)[:, None]
            valid = (rows < N) & (k_offsets < K)
            acc += tl.load(
                input_ptr + pid_m * N * K + rows * K + k_offsets,
                mask=valid,
                other=0,
            ).to(acc_ty)
        col_sums = tl.sum(acc, axis=0, keep_dims=True)

    # Divide by N (the reduction length) to turn sums into means.
    tl.store(
        output_ptr + pid_m * K + k_offsets,
        col_sums / N,
        mask=k_offsets < K,
    )
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def mean_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Mean over the innermost (contiguous) axis of an (M, N)-viewed input.
    Each program reduces one row of length N and writes output[pid_m].
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # The whole row fits in one tile: single masked load + reduction.
        cols = tl.arange(0, TILE_N)
        row = tl.load(
            input_ptr + pid_m * N + cols, mask=cols < N, other=0
        ).to(acc_ty)
        row_sum = tl.sum(row, axis=0)
    else:
        # Row longer than one tile: accumulate across TILE_N-sized chunks.
        acc = tl.zeros(
            [
                TILE_N,
            ],
            dtype=acc_ty,
        )
        for col_start in range(0, N, TILE_N):
            cols = col_start + tl.arange(0, TILE_N)
            acc += tl.load(
                input_ptr + pid_m * N + cols, mask=cols < N, other=0
            ).to(acc_ty)
        row_sum = tl.sum(acc, axis=0)

    # Divide by N (the reduction length) to turn the sum into the mean.
    tl.store(output_ptr + pid_m, row_sum / N)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Naive row-wise mean over an (M, N)-viewed input: each program
    handles BLOCK_M rows, sweeping the N axis in BLOCK_N chunks.
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = inp.dtype.element_ty

    # Rows assigned to this program; column vector for 2-D broadcasting.
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ok = rows < M
    src = inp + rows * N
    dst = out + rows

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=acc_ty)
    for col_start in range(0, N, BLOCK_N):
        cols = col_start + tl.arange(0, BLOCK_N)[None, :]
        tile_mask = row_ok and (cols < N)
        acc += tl.load(src + cols, tile_mask, other=0).to(acc_ty)

    # Reduce across columns, then divide by N to get per-row means.
    row_means = tl.sum(acc, axis=1)[:, None] / N
    tl.store(dst, row_means, row_ok)
def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Mean of `inp` over the dimension(s) in `dim`.

    Dispatches to one of three Triton kernels: a non-inner (M, N, K)
    reduction, an inner (M, N) row reduction, or a naive multi-dim
    reduction after `dim_compress`.

    Args:
        inp: input tensor.
        dim: int, iterable of ints, [] or None. [] / None mean "reduce
            over all elements", matching ``torch.mean``.
        keepdim: keep reduced dimensions with size 1.
        dtype: output dtype; defaults to ``inp.dtype`` (bool promotes to
            int64, mirroring torch's type promotion for reductions).
        out: optional pre-allocated output tensor.

    Returns:
        Tensor of means.

    Raises:
        TypeError: if `dim` is not an int, an iterable of ints, [] or None.
    """
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    # Fix: dim=None (the default) previously fell through to list(None)
    # and raised; torch semantics are "reduce over everything", same as [].
    if dim is None or dim == []:
        if not keepdim:
            return mean(inp, dtype=dtype)
        # keepdim over a full reduction: all dims collapse to size 1.
        return torch.reshape(mean(inp, dtype=dtype), [1] * inp.ndim)

    shape = list(inp.shape)

    # -------- normalize dim to a list of non-negative ints --------
    if isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            )
    dim = [d % inp.ndim for d in dim]
    # --------------------------------------------------------------

    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # Product of dims before dim0 (math.prod of an empty slice is 1).
        M = math.prod(shape[:dim0])
        inp = inp.contiguous()
        K = inp.numel() // M // N  # product of dims after dim0
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # Reduced axis is not innermost: tile over the K axis.
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](out, inp, M, N, K)
            else:
                # Reduced axis is innermost: one program per row.
                mean_dim_kernel_inner[(M, 1, 1)](out, inp, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out

    # Multiple reduction dims: move them to the end, then reduce rows.
    inp = dim_compress(inp, dim)
    N = 1
    for d in dim:
        N *= shape[d]
        shape[d] = 1
    M = inp.numel() // N
    if out is None:
        out = torch.empty(shape, dtype=dtype, device=inp.device)
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        mean_dim_kernel[grid](inp, out, M, N)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Public entry point for dimension-wise mean; defers all work to
    mean_dim_comm with the same arguments.
    """
    logger.debug("GEMS MEAN_DIM (wrapper)")
    return mean_dim_comm(inp, dim, keepdim, dtype=dtype)