Coverage for src/flag_gems/ops/var_mean.py: 49%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@triton.jit
def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
    """Merge two partial (mean, count, M2) statistics into one.

    Used as the ``combine_fn`` of ``tl.reduce``. ``M_*`` is the sum of squared
    deviations from the corresponding mean, ``count_*`` the number of samples.

    The merge uses the delta-based pairwise update

        mean = mean_x + delta * count_y / count
        M    = M_x + M_y + delta^2 * count_x * count_y / count

    instead of rebuilding M from raw second moments
    (``M + count * mean^2`` terms), which subtracts two large, nearly equal
    numbers and loses precision when the mean is large relative to the spread.

    Returns:
        (mean, count, M) of the union of both partitions.
    """
    count = count_x + count_y
    # Both partitions may be empty (fully masked lanes); avoid dividing by 0.
    # Callers in this file initialise empty partitions with mean == 0, so the
    # merged mean of two empty partitions stays 0.
    safe_count = tl.maximum(count, 1)
    delta = mean_y - mean_x
    weight_y = count_y / safe_count
    mean = mean_x + delta * weight_y
    # count_x * weight_y == count_x * count_y / count, without forming the
    # potentially huge product count_x * count_y first.
    M = M_x + M_y + delta * delta * count_x * weight_y
    return mean, count, M

24 

25 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise variance and mean using a running Welford update.

    X is addressed as M rows of N consecutive elements. Each program handles
    BLOCK_M rows and walks the columns in tiles of BLOCK_N, keeping one
    running (mean, count, M2) triple per tile lane in float32. The lanes are
    merged with ``welford_func`` at the end, and the variance is obtained by
    dividing M2 by ``N - correction``.
    """
    # Rows owned by this program, shaped [BLOCK_M, 1] for broadcasting.
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_mask = rows < M
    row_start = X + rows * N

    run_mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    run_m2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    run_count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        mask = row_mask and (cols < N)

        x = tl.load(row_start + cols, mask, other=0.0).to(tl.float32)

        # Masked-out lanes keep their count and contribute nothing to M2.
        new_count = run_count + mask
        new_mean = (run_mean * run_count + x) / tl.maximum(new_count, 1)
        run_m2 += (x - new_mean) * (x - run_mean) * mask
        run_mean = new_mean
        run_count = new_count

    # Collapse the BLOCK_N per-lane statistics of every row into one triple.
    mean, _, m2 = tl.reduce(
        (run_mean, run_count, run_m2), axis=1, combine_fn=welford_func
    )
    var = m2 / (N - correction)

    # Write mean / var, one value per valid row.
    tl.store(Mean + rows, mean[:, None], row_mask)
    tl.store(Var + rows, var[:, None], row_mask)

70 

71 

@libentry()
@triton.jit
def var_mean_kernel_1(
    X,
    Acc,
    Average,
    Count,
    N,
    BLOCK_N: tl.constexpr,
):
    """Stage 1 of the full reduction: per-block mean, count and M2.

    X is addressed by flat element offsets. Program ``pid`` reads the
    BLOCK_N elements starting at ``pid * BLOCK_N`` (masked against N) and
    writes three scalars at index ``pid``:

    * ``Average[pid]`` - mean of the valid elements of the block,
    * ``Count[pid]``   - number of valid elements,
    * ``Acc[pid]``     - sum of squared deviations from that mean.

    The deviations are summed directly rather than computing
    ``sum(x^2) - count * mean^2``, which cancels catastrophically in float32
    when the mean is large compared to the spread of the data.
    """
    # Map the program id to the flat block of X it should compute.
    pid = tle.program_id(0)
    offset = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offset < N

    x = tl.load(X + offset, mask, other=0.0).to(tl.float32)

    count = tl.sum(mask.to(tl.float32))
    average = tl.sum(x) / count
    # Padding lanes were loaded as 0.0; zero their deviation explicitly so
    # they do not contribute (0 - average)^2 to the sum.
    centered = tl.where(mask, x - average, 0.0)
    acc = tl.sum(centered * centered)

    tl.store(Average + pid, average)
    tl.store(Acc + pid, acc)
    tl.store(Count + pid, count)

101 

102 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("var_mean"))
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
    Acc,
    Average,
    Count,
    Var,
    Mean,
    N,
    correction,
    BLOCK_NUM,
    BLOCK_N: tl.constexpr,
):
    """Stage 2 of the full reduction: merge the per-block partial statistics.

    Loads up to BLOCK_N entries of ``Average``/``Count``/``Acc`` (only the
    first BLOCK_NUM are valid, the rest are read as 0), merges them with
    ``welford_func`` and stores a single mean and a single variance, where
    the variance is the merged M2 divided by ``N - correction``.

    NOTE(review): the caller does not pass BLOCK_N, so it is presumably
    chosen by the "var_mean" heuristic config to be >= BLOCK_NUM - confirm.
    """
    idx = tl.arange(0, BLOCK_N)
    valid = idx < BLOCK_NUM

    part_mean = tl.load(Average + idx, valid, other=0.0).to(tl.float32)
    part_count = tl.load(Count + idx, valid, other=0.0).to(tl.float32)
    part_m2 = tl.load(Acc + idx, valid, other=0.0).to(tl.float32)

    mean, _, m2 = tl.reduce(
        (part_mean, part_count, part_m2), axis=0, combine_fn=welford_func
    )

    tl.store(Mean, mean)
    tl.store(Var, m2 / (N - correction))

131 

132 

def var_mean(x, dim=None, *, correction=None, keepdim=False):
    """Compute the variance and the mean of ``x`` over ``dim``.

    Args:
        x: Input tensor.
        dim: Dimension(s) to reduce. ``None``, an empty list, or a list naming
            every dimension reduces over all elements. A single int is
            accepted as shorthand for a one-element list.
        correction: Value subtracted from the sample count in the variance
            divisor (``N - correction``). ``None`` means ``1.0``.
        keepdim: If true, reduced dimensions are kept with size 1; otherwise
            they are squeezed out of both results.

    Returns:
        A ``(var, mean)`` tuple of tensors with the dtype and device of ``x``.
    """
    logger.debug("GEMS VAR MEAN")
    if correction is None:
        correction = 1.0
    # Allow `dim=1` in addition to `dim=[1]`; len() below needs a sequence.
    if isinstance(dim, int):
        dim = [dim]

    # An empty dim list is routed to the full reduction: the partial path
    # would reduce over N == 1 element and divide by N - correction.
    # NOTE(review): believed to match torch's handling of an empty dim list
    # (reduce over everything) - confirm.
    if dim is None or len(dim) == 0 or len(dim) == x.ndim:
        dim = list(range(x.ndim))
        shape = [1] * x.ndim
        N = x.numel()
        # Stage 1 addresses x by flat offsets, so it needs dense storage.
        x = x.contiguous()
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)
        BLOCK_N = 1024
        BLOCK_NUM = triton.cdiv(N, BLOCK_N)
        # The kernels compute in float32; keep the per-block partials in
        # float32 as well so half-precision inputs neither overflow the
        # squared sums nor round the element counts between the two stages.
        acc = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
        average = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
        count = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)

        with torch_device_fn.device(x.device):
            var_mean_kernel_1[(BLOCK_NUM,)](x, acc, average, count, N, BLOCK_N=BLOCK_N)
            var_mean_kernel_2[(1,)](
                acc, average, count, var, mean, N, correction, BLOCK_NUM
            )
    else:
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]
        x = dim_compress(x, dim)
        # N: number of elements reduced per output value; the output shape
        # keeps the reduced dimensions with size 1.
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = x.numel() // N
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)

        def grid(META):
            return (triton.cdiv(M, META["BLOCK_M"]),)

        with torch_device_fn.device(x.device):
            var_mean_welford_kernel[grid](x, var, mean, M, N, correction)

    if not keepdim:
        var = var.squeeze(dim=dim)
        mean = mean.squeeze(dim=dim)
    return var, mean