Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/std.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.utils import dim_compress 

9 

10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

11 

12 

@triton.jit
def _std_map_kernel(X, Tmp_sum, Tmp_sum_sq, N, BLOCK_N: tl.constexpr):
    # Map stage of the global std reduction: each program instance owns one
    # tile of BLOCK_N elements and emits its partial sum and partial sum of
    # squares into per-program scratch slots (reduced later by
    # _std_reduce_kernel). Out-of-range lanes load 0.0, which is neutral for
    # both accumulations.
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = idx < N
    vals = tl.load(X + idx, mask=in_bounds, other=0.0).to(tl.float32)
    tl.store(Tmp_sum + block_id, tl.sum(vals, axis=0))
    tl.store(Tmp_sum_sq + block_id, tl.sum(vals * vals, axis=0))

23 

24 

@triton.jit
def _std_reduce_kernel(
    Tmp_sum, Tmp_sum_sq, Out, N, correction, BLOCK_NUM, BLOCK_SIZE: tl.constexpr
):
    # Reduce stage of the global std reduction: a single program instance
    # folds the BLOCK_NUM per-block partials produced by _std_map_kernel into
    # one scalar standard deviation and stores it at Out.
    #
    # Accumulate in BLOCK_SIZE-wide vector registers first, then collapse to
    # scalars with tl.sum — fewer sequential adds than a scalar loop.
    total_sum_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    total_sum_sq_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for off in range(0, BLOCK_NUM, BLOCK_SIZE):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < BLOCK_NUM
        # other=0.0 keeps masked lanes neutral for the additions below.
        tmp_sum_vals = tl.load(Tmp_sum + offset, mask=mask, other=0.0).to(tl.float32)
        tmp_sum_sq_vals = tl.load(Tmp_sum_sq + offset, mask=mask, other=0.0).to(
            tl.float32
        )
        total_sum_acc += tmp_sum_vals
        total_sum_sq_acc += tmp_sum_sq_vals
    total_sum = tl.sum(total_sum_acc, axis=0)
    total_sum_sq = tl.sum(total_sum_sq_acc, axis=0)
    mean = total_sum / N
    # One-pass variance formula E[x^2] - E[x]^2; the subtraction can go
    # slightly negative from float cancellation, handled by the clamp below.
    var = (total_sum_sq / N) - (mean * mean)
    # Rescale from population variance (divisor N) to divisor N - correction.
    # The caller (std, dim=None path) only launches this kernel when
    # N - correction > 0, so the 1.0 clamp is a safety net, not a live path.
    var = var * N / tl.maximum(N - correction, 1.0)
    safe_var = tl.maximum(var, 0.0)
    std_dev = tl.sqrt(safe_var)
    tl.store(Out, std_dev.to(Out.dtype.element_ty))

48 

49 

def _std_fused_dim_kernel_m(args):
    """Heuristic for BLOCK_M: ceil-divide the M rows across the clusters."""
    cluster_num = 12  # NOTE(review): hardware cluster count for this backend — confirm
    return triton.cdiv(args["M"], cluster_num)

53 

54 

55def _std_fused_dim_kernel_n(args): 

56 import builtins 

57 

58 return builtins.min(args["N"], 8192) 

59 

60 

# @triton.autotune(configs=runtime.get_tuned_config("naive_reduction"), key=["M", "N"])
@triton.heuristics(
    values={
        "BLOCK_M": _std_fused_dim_kernel_m,
        "BLOCK_N": _std_fused_dim_kernel_n,
    },
)
@triton.jit
def _std_fused_dim_kernel(
    X,
    Out,
    stride_x_row,
    stride_x_col,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise std over a logical (M, N) view of X: each program instance
    # reduces BLOCK_M rows of length N, in two passes (mean, then squared
    # deviations). One output value per row is written to Out[row].
    pid_group = tl.program_id(axis=0)
    start_row = pid_group * BLOCK_M
    row_offsets = start_row + tl.arange(0, BLOCK_M)
    row_mask = row_offsets < M

    # Pass 1: accumulate raw values per row to compute the mean.
    mean_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    x_row_ptrs = X + row_offsets[:, None] * stride_x_row

    for off in range(0, N, BLOCK_N):
        col_offsets = off + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < N
        x_ptrs = x_row_ptrs + col_offsets[None, :] * stride_x_col
        final_mask = row_mask[:, None] & col_mask[None, :]
        # other=0.0 keeps masked lanes neutral for the sum.
        x = tl.load(x_ptrs, mask=final_mask, other=0.0)
        mean_acc += x.to(tl.float32)

    mean = tl.sum(mean_acc, axis=1) / N

    # Pass 2: accumulate squared deviations from the per-row mean.
    var_acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        col_offsets = off + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < N
        x_ptrs = x_row_ptrs + col_offsets[None, :] * stride_x_col
        final_mask = row_mask[:, None] & col_mask[None, :]
        x = tl.load(x_ptrs, mask=final_mask, other=0.0)
        diff = x.to(tl.float32) - mean[:, None]
        # Masked lanes would contribute (0 - mean)^2, so zero them explicitly.
        var_acc += tl.where(final_mask, diff * diff, 0.0)

    var = tl.sum(var_acc, axis=1)

    # Divisor N - correction; the caller returns NaN before launching when it
    # is <= 0, so the 1e-12 clamp is only a guard against division by zero.
    denom = N - correction
    var = var / tl.maximum(denom, 1e-12)
    # Clamp tiny negative float residue before sqrt.
    safe_var = tl.maximum(var, 0.0)
    std_dev = tl.sqrt(safe_var)

    out_ptrs = Out + row_offsets
    tl.store(out_ptrs, std_dev.to(Out.dtype.element_ty), mask=row_mask)

117 

118 

119def std(x, dim=None, *, correction=None, keepdim=False): 

120 effective_correction = 1.0 if correction is None else float(correction) 

121 original_shape = x.shape 

122 input_ndim = x.ndim 

123 

124 if dim is None: 

125 logger.debug("GEMS STD (Global Simple Map-Reduce Path)") 

126 N = x.numel() 

127 if N == 0 or N - effective_correction <= 0: 

128 return torch.full([], float("nan"), device=x.device, dtype=x.dtype) 

129 

130 BLOCK_N_MAP = 1024 

131 BLOCK_NUM = triton.cdiv(N, BLOCK_N_MAP) 

132 tmp_sum = torch.empty((BLOCK_NUM,), dtype=torch.float32, device=x.device) 

133 tmp_sum_sq = torch.empty((BLOCK_NUM,), dtype=torch.float32, device=x.device) 

134 _std_map_kernel[(BLOCK_NUM,)]( 

135 x.contiguous(), tmp_sum, tmp_sum_sq, N, BLOCK_N_MAP 

136 ) 

137 out = torch.empty([], device=x.device, dtype=x.dtype) 

138 BLOCK_SIZE_REDUCE = 1024 

139 _std_reduce_kernel[(1,)]( 

140 tmp_sum, 

141 tmp_sum_sq, 

142 out, 

143 N, 

144 effective_correction, 

145 BLOCK_NUM, 

146 BLOCK_SIZE_REDUCE, 

147 ) 

148 return out.view([1] * input_ndim) if keepdim else out 

149 

150 else: 

151 logger.warning( 

152 f"GEMS std: Using compatible but non-optimal path for dim={dim} (dim_compress)." 

153 ) 

154 

155 if isinstance(dim, int): 

156 dim_list = [dim] 

157 else: 

158 dim_list = list(dim) 

159 dim_list_normalized = [d % input_ndim for d in dim_list] 

160 

161 x_view = dim_compress(x, dim_list_normalized) 

162 

163 N = 1 

164 for d in dim_list_normalized: 

165 N *= original_shape[d] 

166 M = x.numel() // N 

167 

168 stride_x_row, stride_x_col = N, 1 

169 

170 output_shape_kept = list(original_shape) 

171 for d in dim_list_normalized: 

172 output_shape_kept[d] = 1 

173 

174 if M * N > 0 and (N - effective_correction <= 0): 

175 final_shape = [ 

176 s for i, s in enumerate(original_shape) if i not in dim_list_normalized 

177 ] 

178 return torch.full( 

179 final_shape if not keepdim else output_shape_kept, 

180 float("nan"), 

181 device=x.device, 

182 dtype=x.dtype, 

183 ) 

184 

185 out = torch.empty(output_shape_kept, device=x.device, dtype=x.dtype) 

186 if M * N == 0: 

187 return out.squeeze(dim=tuple(dim_list_normalized)) if not keepdim else out 

188 

189 grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) 

190 

191 _std_fused_dim_kernel[grid]( 

192 x_view, out.view(M), stride_x_row, stride_x_col, M, N, effective_correction 

193 ) 

194 

195 return out.squeeze(dim=tuple(dim_list_normalized)) if not keepdim else out