Coverage for src/flag_gems/runtime/backend/_metax/ops/groupnorm.py: 0%

147 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, tl_extra_shim 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11rsqrt = tl_extra_shim.rsqrt 

12 

13logger = logging.getLogger("flag_gems." + __name__) 

14 

15 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
    X,  # input, contiguous, viewed as (N, C, HW)
    Y,  # output, same shape/dtype as X
    W,  # per-channel weight (gamma) of length C, or None
    B,  # per-channel bias (beta) of length C, or None
    Mean,  # out: per-(sample, group) mean, shape (N, num_groups)
    Rstd,  # out: per-(sample, group) reciprocal std, shape (N, num_groups)
    group_size,  # channels per group == C // num_groups
    C,
    HW,  # flattened spatial size per channel
    num_groups,
    eps,
    BLOCK_GROUP_SIZE: tl.constexpr,  # next_power_of_2(group_size)
    BLOCK_HW_SIZE: tl.constexpr,  # next_power_of_2(HW)
):
    # Forward group normalization. One program instance handles one
    # (sample, group) pair; the launch grid is (N * num_groups,).
    pid = tle.program_id(0)
    group = pid % num_groups
    # Real (unpadded) element count of one group; also the flat stride
    # between consecutive (sample, group) slices in the contiguous input.
    num_elements = group_size * HW
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)

    # Channel indices this group covers, used to index W and B.
    wb_offset = group * group_size + group_offset
    wb_mask = wb_offset < C

    # 2-D (channel, spatial) offsets of this group's elements.
    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
    xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW

    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid

    X_ptr = X + xy_offset
    Y_ptr = Y + xy_offset

    # Masked lanes load 0.0 so they contribute nothing to the reductions.
    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    mean = tl.sum(X_val) / num_elements
    # Re-zero the padded lanes after centering, so the variance sum stays exact.
    x = tl.where(xy_mask, X_val - mean, 0.0)

    var = tl.sum(x * x) / num_elements
    rstd = rsqrt(var + eps)
    x_hat = x * rstd

    # W/B are optional: fall back to identity affine (gamma=1, beta=0).
    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0)[:, None]
    if B is None:
        bias = 0
    else:
        bias = tl.load(B + wb_offset, mask=wb_mask, other=0.0)[:, None]
    Y_val = x_hat * weight + bias

    tl.store(Y_ptr, Y_val, mask=xy_mask)
    # Statistics are stored once per program (per sample-group pair).
    tl.store(Mean_ptr, mean)
    tl.store(Rstd_ptr, rstd)

72 

73 

@libentry()
@triton.jit
def group_norm_backward_kernel(
    grad_y,  # upstream gradient dL/dY, contiguous (N, C, HW)
    X,  # forward input, contiguous (N, C, HW)
    W,  # per-channel weight (gamma) of length C, or None
    Mean,  # per-(sample, group) mean saved by the forward kernel
    Rstd,  # per-(sample, group) reciprocal std saved by the forward kernel
    num_groups,
    group_size,  # channels per group == C // num_groups
    grad_x,  # out: dL/dX, same shape as X
    C,
    HW,  # flattened spatial size per channel
    BLOCK_GROUP_SIZE: tl.constexpr,  # next_power_of_2(group_size)
    BLOCK_HW_SIZE: tl.constexpr,  # next_power_of_2(HW)
):
    # Input-gradient backward pass. One program instance handles one
    # (sample, group) pair; the launch grid is (N * num_groups,).
    pid = tle.program_id(0)
    group = pid % num_groups
    # BUGFIX: was `group_size * BLOCK_HW_SIZE` (the power-of-two padded
    # block size). num_elements is both the flat stride between consecutive
    # (sample, group) slices in the contiguous tensor and the divisor of the
    # grad-mean reduction, so it must match the forward kernel's real group
    # size `group_size * HW`; with the padded value the kernel addressed the
    # wrong memory and mis-normalized whenever HW was not a power of two.
    num_elements = group_size * HW

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)
    # Channel indices this group covers, used to index W.
    wb_offset = group * group_size + group_offset

    wb_mask = wb_offset < C

    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
    xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW

    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid
    X_ptr = X + xy_offset
    dY_ptr = grad_y + xy_offset
    dX_ptr = grad_x + xy_offset

    rstd = tl.load(Rstd_ptr).to(tl.float32)
    mean = tl.load(Mean_ptr).to(tl.float32)
    dY_val = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)

    if W is None:
        weight = 1  # identity affine when no gamma was applied
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0).to(tl.float32)[:, None]

    # dL/d(x_hat) = dL/dY * gamma
    dx_hat = weight * dY_val

    # Centered input; padded lanes are zeroed so reductions stay exact.
    x = tl.where(xy_mask, X_val - mean, 0.0)

    # Standard group-norm backward decomposition:
    # gradient through rstd (variance), then through the mean.
    grad_std = tl.sum(dx_hat * x)
    grad_var = grad_std * -(0.5 * rstd * rstd * rstd) / (HW * group_size)
    grad_distance = 2 * x * grad_var
    grad_centered_mean = dx_hat * rstd + grad_distance
    grad_mean = -tl.sum(grad_centered_mean) / num_elements
    grad_X = grad_centered_mean + grad_mean
    tl.store(dX_ptr, grad_X, mask=xy_mask)

130 

131 

@libentry()
@triton.jit
def weight_bias_backward_kernel(
    dY,  # upstream gradient dL/dY, contiguous (N, C, HW)
    X,  # forward input, contiguous (N, C, HW)
    Mean,  # per-(sample, group) mean, laid out (N, num_groups)
    Rstd,  # per-(sample, group) reciprocal std, laid out (N, num_groups)
    dW,  # out: dL/dgamma of length C, or None
    dB,  # out: dL/dbeta of length C, or None
    num_groups,
    group_size,  # channels per group == C // num_groups
    N,
    C,
    HW,  # flattened spatial size per channel
    BLOCK_N: tl.constexpr,  # next_power_of_2(N)
    BLOCK_HW: tl.constexpr,  # next_power_of_2(HW)
):
    # Affine-parameter backward pass. One program instance reduces one
    # channel `pid` over all N samples and HW positions; grid is (C, 1, 1).
    pid = tle.program_id(0)
    group = pid // group_size  # group that channel `pid` belongs to
    n_offset = tl.arange(0, BLOCK_N)
    hw_offset = tl.arange(0, BLOCK_HW)
    xy_mask = n_offset[:, None] < N and hw_offset[None, :] < HW
    mr_mask = n_offset < N

    # Pick this channel's group statistic for every sample.
    mean_ptr = Mean + group + n_offset * num_groups
    rstd_ptr = Rstd + group + n_offset * num_groups

    # BUGFIX: the base offset of channel `pid` within a sample is `pid * HW`
    # (contiguous (N, C, HW) layout), not `pid * BLOCK_HW`; using the
    # power-of-two padded block size addressed the wrong channel whenever
    # HW was not a power of two.
    dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
    x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]

    grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    x = tl.load(x_ptr, mask=xy_mask, other=0.0)
    x_f32 = x.to(tl.float32)
    mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]

    if dW is not None:
        # dgamma_c = sum over (n, hw) of x_hat * dY
        dw = tl.sum((x_f32 - mean) * rstd * grad_y, 1)
        dw = tl.sum(dw)
        tl.store(dW + pid, dw.to(x.dtype))
    if dB is not None:
        # dbeta_c = sum over (n, hw) of dY
        db = tl.sum(grad_y, 1)
        db = tl.sum(db)
        tl.store(dB + pid, db.to(x.dtype))

176 

177 

class GroupNorm(torch.autograd.Function):
    """Group normalization with autograd support, backed by Triton kernels.

    The input is treated as a contiguous (N, C, HW) tensor whose C channels
    are split into ``num_groups`` groups of ``C // num_groups`` channels.
    ``forward`` returns ``(y, mean, rstd)`` so callers can mimic
    ``torch.ops.aten.native_group_norm``'s output contract.
    """

    @staticmethod
    def forward(ctx, x, N, C, HW, num_groups, weight=None, bias=None, eps=1e-05):
        logger.debug("METAX GEMS GROUPNORM FORWARD")
        group_size = C // num_groups
        # The kernels index with flat (N, C, HW) arithmetic, so all inputs
        # must be contiguous.
        x = x.contiguous()
        if weight is not None:
            weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y = torch.empty_like(x)
        # Per-(sample, group) statistics; also consumed by backward.
        mean = torch.empty((N, num_groups), dtype=x.dtype, device=x.device)
        rstd = torch.empty((N, num_groups), dtype=x.dtype, device=x.device)
        grid = (N * num_groups,)

        with torch_device_fn.device(x.device):
            group_norm_kernel[grid](
                x,
                y,
                weight,
                bias,
                mean,
                rstd,
                group_size,
                C,
                HW,
                num_groups,
                eps,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
                BLOCK_HW_SIZE=triton.next_power_of_2(HW),
            )
        # BUGFIX: save whenever ANY differentiable input needs a gradient,
        # not only x. Previously, if just weight/bias required grad,
        # backward() found ctx.saved_tensors empty and crashed.
        if any(ctx.needs_input_grad):
            ctx.save_for_backward(x, weight, bias, mean, rstd)
            ctx.num_groups = num_groups
            ctx.group_size = group_size
            ctx.N = N
            ctx.C = C
            ctx.HW = HW
        return y, mean, rstd

    @staticmethod
    def backward(ctx, y_grad, mean_grad, rstd_grad):
        # mean_grad / rstd_grad are ignored: the saved statistics are
        # treated as byproducts, matching native_group_norm's backward.
        logger.debug("METAX GEMS GROUPNORM BACKWARD")
        y_grad = y_grad.contiguous()
        (x, weight, bias, mean, rstd) = ctx.saved_tensors
        num_groups = ctx.num_groups
        group_size = ctx.group_size
        N = ctx.N
        C = ctx.C
        HW = ctx.HW
        x_grad = torch.empty_like(x)
        # One program per (sample, group) pair, mirroring the forward launch.
        grid = (N * num_groups,)
        with torch_device_fn.device(x.device):
            group_norm_backward_kernel[grid](
                y_grad,
                x,
                weight,
                mean,
                rstd,
                num_groups,
                group_size,
                x_grad,
                C,
                HW,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
                BLOCK_HW_SIZE=triton.next_power_of_2(HW),
            )
        if weight is None and bias is None:
            # No affine parameters: only x receives a gradient. The Nones
            # line up with forward's (x, N, C, HW, num_groups, weight, bias,
            # eps) signature.
            return x_grad, None, None, None, None, None, None, None

        weight_grad = None if weight is None else torch.empty_like(weight)
        bias_grad = None if bias is None else torch.empty_like(bias)
        with torch_device_fn.device(x.device):
            # One program per channel reduces over all samples and positions.
            weight_bias_backward_kernel[(C, 1, 1)](
                y_grad,
                x,
                mean,
                rstd,
                weight_grad,
                bias_grad,
                num_groups,
                group_size,
                N,
                C,
                HW,
                BLOCK_N=triton.next_power_of_2(N),
                BLOCK_HW=triton.next_power_of_2(HW),
            )
        return x_grad, None, None, None, None, weight_grad, bias_grad, None

267 

268 

def group_norm(x, weight, bias, N, C, HW, num_groups, eps):
    """Functional entry point: dispatch to the GroupNorm autograd Function.

    Note the argument reordering — ``apply`` takes the affine parameters
    after the shape arguments, with ``weight``/``bias`` passed positionally.
    """
    result = GroupNorm.apply(x, N, C, HW, num_groups, weight, bias, eps)
    return result