Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/std.py: 0%
114 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.utils import dim_compress
# Module logger nested under the package-wide "flag_gems" logger hierarchy.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def _std_map_kernel(X, Tmp_sum, Tmp_sum_sq, N, BLOCK_N: tl.constexpr):
    # Map stage of the global std: each program reduces one BLOCK_N slice of X
    # into a partial sum and a partial sum of squares (both in fp32).
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = idx < N
    vals = tl.load(X + idx, mask=in_bounds, other=0.0).to(tl.float32)
    partial_sum = tl.sum(vals, axis=0)
    partial_sum_sq = tl.sum(vals * vals, axis=0)
    tl.store(Tmp_sum + block_id, partial_sum)
    tl.store(Tmp_sum_sq + block_id, partial_sum_sq)
@triton.jit
def _std_reduce_kernel(
    Tmp_sum, Tmp_sum_sq, Out, N, correction, BLOCK_NUM, BLOCK_SIZE: tl.constexpr
):
    # Reduce stage: a single program folds all per-block partial sums into
    # one scalar standard deviation and writes it to Out.
    acc_sum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    acc_sum_sq = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for start in range(0, BLOCK_NUM, BLOCK_SIZE):
        idx = start + tl.arange(0, BLOCK_SIZE)
        valid = idx < BLOCK_NUM
        acc_sum += tl.load(Tmp_sum + idx, mask=valid, other=0.0).to(tl.float32)
        acc_sum_sq += tl.load(Tmp_sum_sq + idx, mask=valid, other=0.0).to(
            tl.float32
        )
    grand_sum = tl.sum(acc_sum, axis=0)
    grand_sum_sq = tl.sum(acc_sum_sq, axis=0)
    mean = grand_sum / N
    # E[x^2] - E[x]^2 gives the population variance; rescale by
    # N / (N - correction) for the corrected estimator.
    pop_var = (grand_sum_sq / N) - (mean * mean)
    corrected_var = pop_var * N / tl.maximum(N - correction, 1.0)
    # Clamp small negatives caused by floating-point cancellation before sqrt.
    std_dev = tl.sqrt(tl.maximum(corrected_var, 0.0))
    tl.store(Out, std_dev.to(Out.dtype.element_ty))
def _std_fused_dim_kernel_m(args):
    # BLOCK_M heuristic: spread the M rows over 12 programs (cluster count).
    rows = args["M"]
    return triton.cdiv(rows, 12)
55def _std_fused_dim_kernel_n(args):
56 import builtins
58 return builtins.min(args["N"], 8192)
# @triton.autotune(configs=runtime.get_tuned_config("naive_reduction"), key=["M", "N"])
@triton.heuristics(
    values={
        "BLOCK_M": _std_fused_dim_kernel_m,
        "BLOCK_N": _std_fused_dim_kernel_n,
    },
)
@triton.jit
def _std_fused_dim_kernel(
    X,
    Out,
    stride_x_row,
    stride_x_col,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Fused per-row std over an (M, N) view of X: each program handles BLOCK_M
    # rows and makes two passes over the N columns (mean, then variance),
    # writing one std value per row to Out.
    pid_group = tl.program_id(axis=0)
    start_row = pid_group * BLOCK_M
    row_offsets = start_row + tl.arange(0, BLOCK_M)
    row_mask = row_offsets < M
    # Pass 1: accumulate row sums; masked lanes load 0.0 so they do not
    # perturb the sum, and the exact divide by N below gives the mean.
    mean_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    x_row_ptrs = X + row_offsets[:, None] * stride_x_row
    for off in range(0, N, BLOCK_N):
        col_offsets = off + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < N
        x_ptrs = x_row_ptrs + col_offsets[None, :] * stride_x_col
        final_mask = row_mask[:, None] & col_mask[None, :]
        x = tl.load(x_ptrs, mask=final_mask, other=0.0)
        mean_acc += x.to(tl.float32)
    mean = tl.sum(mean_acc, axis=1) / N
    # Pass 2: accumulate squared deviations from the row mean. Masked lanes
    # must be explicitly zeroed here because (0.0 - mean)^2 is nonzero.
    var_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        col_offsets = off + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < N
        x_ptrs = x_row_ptrs + col_offsets[None, :] * stride_x_col
        final_mask = row_mask[:, None] & col_mask[None, :]
        x = tl.load(x_ptrs, mask=final_mask, other=0.0)
        diff = x.to(tl.float32) - mean[:, None]
        var_acc += tl.where(final_mask, diff * diff, 0.0)
    var = tl.sum(var_acc, axis=1)
    # Divide by (N - correction); the 1e-12 floor only guards against a zero
    # denominator — the host wrapper is expected to return NaN before
    # launching when N - correction <= 0.
    denom = N - correction
    var = var / tl.maximum(denom, 1e-12)
    # Clamp small negative values from floating-point rounding before sqrt.
    safe_var = tl.maximum(var, 0.0)
    std_dev = tl.sqrt(safe_var)
    out_ptrs = Out + row_offsets
    tl.store(out_ptrs, std_dev.to(Out.dtype.element_ty), mask=row_mask)
def std(x, dim=None, *, correction=None, keepdim=False):
    """Standard deviation of ``x``, globally or over ``dim``.

    Mirrors ``torch.std`` semantics: ``correction`` defaults to 1 (Bessel's
    correction), ``keepdim`` keeps reduced dimensions as size 1, and a
    non-positive denominator ``N - correction`` yields NaN.
    """
    effective_correction = 1.0 if correction is None else float(correction)
    original_shape = x.shape
    input_ndim = x.ndim

    if dim is None:
        logger.debug("GEMS STD (Global Simple Map-Reduce Path)")
        N = x.numel()
        if N == 0 or N - effective_correction <= 0:
            # Degenerate denominator -> NaN.
            nan_out = torch.full([], float("nan"), device=x.device, dtype=x.dtype)
            # Bugfix: honor keepdim on this early return, matching the
            # shape handling of the normal return below.
            return nan_out.view([1] * input_ndim) if keepdim else nan_out
        # Two-stage map/reduce: per-block partial sums, then one final fold.
        BLOCK_N_MAP = 1024
        BLOCK_NUM = triton.cdiv(N, BLOCK_N_MAP)
        tmp_sum = torch.empty((BLOCK_NUM,), dtype=torch.float32, device=x.device)
        tmp_sum_sq = torch.empty((BLOCK_NUM,), dtype=torch.float32, device=x.device)
        _std_map_kernel[(BLOCK_NUM,)](
            x.contiguous(), tmp_sum, tmp_sum_sq, N, BLOCK_N_MAP
        )
        out = torch.empty([], device=x.device, dtype=x.dtype)
        BLOCK_SIZE_REDUCE = 1024
        _std_reduce_kernel[(1,)](
            tmp_sum,
            tmp_sum_sq,
            out,
            N,
            effective_correction,
            BLOCK_NUM,
            BLOCK_SIZE_REDUCE,
        )
        return out.view([1] * input_ndim) if keepdim else out

    else:
        logger.warning(
            f"GEMS std: Using compatible but non-optimal path for dim={dim} (dim_compress)."
        )
        if isinstance(dim, int):
            dim_list = [dim]
        else:
            dim_list = list(dim)
        dim_list_normalized = [d % input_ndim for d in dim_list]

        # Move the reduced dims innermost: rows of kept elements (M) by
        # columns of reduced elements (N).
        x_view = dim_compress(x, dim_list_normalized)
        N = 1
        for d in dim_list_normalized:
            N *= original_shape[d]
        M = x.numel() // N

        # NOTE(review): assumes dim_compress returns a contiguous (M, N)
        # layout so row/col strides are (N, 1) — confirm against its impl.
        stride_x_row, stride_x_col = N, 1

        output_shape_kept = list(original_shape)
        for d in dim_list_normalized:
            output_shape_kept[d] = 1

        # Bugfix: the old guard was `M * N > 0 and ...`, which skipped N == 0
        # and fell through to return uninitialized `torch.empty` memory.
        # Any non-positive denominator (including an empty reduced dim when
        # correction >= 0) must produce NaN, matching torch.std.
        if M > 0 and N - effective_correction <= 0:
            final_shape = [
                s for i, s in enumerate(original_shape) if i not in dim_list_normalized
            ]
            return torch.full(
                output_shape_kept if keepdim else final_shape,
                float("nan"),
                device=x.device,
                dtype=x.dtype,
            )

        out = torch.empty(output_shape_kept, device=x.device, dtype=x.dtype)
        if M == 0:
            # No rows to fill: return the (empty) output without launching.
            return out if keepdim else out.squeeze(dim=tuple(dim_list_normalized))

        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
        _std_fused_dim_kernel[grid](
            x_view, out.view(M), stride_x_row, stride_x_col, M, N, effective_correction
        )
        return out if keepdim else out.squeeze(dim=tuple(dim_list_normalized))