Coverage for src/flag_gems/ops/groupnorm.py: 37%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, tl_extra_shim 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Backend-specific device-side reciprocal square root (shimmed per target).
rsqrt = tl_extra_shim.rsqrt
# Module-level logger, named after this module per project convention.
logger = logging.getLogger(__name__)

13 

14 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
    X,
    Y,
    W,
    B,
    Mean,
    Rstd,
    group_size,
    C,
    HW,
    num_groups,
    eps,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr,
):
    """Forward group-norm: one program handles one (sample, group) slice.

    Grid is (N * num_groups,). Writes the normalized (and optionally affine
    transformed) result to Y, and the per-(sample, group) mean and reciprocal
    standard deviation to Mean / Rstd.

    W and B are optional per-channel affine parameters (length C); when a
    pointer is None the corresponding transform degenerates to identity.
    """
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)

    # Channel indices covered by this group, clipped to C for ragged groups.
    wb_offset = group * group_size + group_offset
    wb_mask = wb_offset < C

    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
    # Fix: combine the 2D bounds mask with elementwise `&`, not Python `and`.
    # `and` short-circuits on tensor truthiness and is not a reliable
    # elementwise operation; this also matches the backward kernels.
    xy_mask = (wb_offset[:, None] < C) & (hw_offset[None, :] < HW)

    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid

    X_ptr = X + xy_offset
    Y_ptr = Y + xy_offset

    # Accumulate statistics in fp32 for numerical stability regardless of
    # the input dtype.
    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    mean = tl.sum(X_val) / num_elements
    # Zero masked lanes so out-of-bounds positions do not pollute the variance.
    x = tl.where(xy_mask, X_val - mean, 0.0)

    var = tl.sum(x * x) / num_elements
    rstd = rsqrt(var + eps)
    x_hat = x * rstd

    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0)[:, None]
    if B is None:
        bias = 0
    else:
        bias = tl.load(B + wb_offset, mask=wb_mask, other=0.0)[:, None]
    Y_val = x_hat * weight + bias

    tl.store(Y_ptr, Y_val, mask=xy_mask)
    tl.store(Mean_ptr, mean)
    tl.store(Rstd_ptr, rstd)

71 

72 

@libentry()
@triton.jit
def group_norm_backward_kernel(
    grad_y,
    X,
    W,
    Mean,
    Rstd,
    num_groups,
    group_size,
    grad_x,
    C,
    HW,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr = 128,
):
    """Backward group-norm w.r.t. the input: one program per (sample, group).

    Grid is (N * num_groups,). Uses the saved per-(sample, group) Mean/Rstd
    from the forward pass. Two passes over the HW dimension in tiles of
    BLOCK_HW_SIZE: the first accumulates the two reduction terms of the
    group-norm gradient, the second computes and stores grad_x.
    """
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    wb_offset = group * group_size + group_offset

    wb_mask = wb_offset < C

    rstd = tl.load(Rstd + pid).to(tl.float32)
    mean = tl.load(Mean + pid).to(tl.float32)
    # W is optional; without affine weight, dx_hat == dY.
    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0).to(tl.float32)[:, None]

    # Pass 1: accumulate sum(dx_hat) and sum(dx_hat * x_hat) over the group.
    dx_part2 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]

        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)

        # Masked lanes are zeroed so they contribute nothing to the sums.
        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2)
    dx_3 = tl.sum(dx_part3)

    # Pass 2: re-load tiles and apply the standard group-norm input gradient:
    # dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / M)
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]

        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)

        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / num_elements)

        tl.store(grad_x + xy_offset, dx, xy_mask)

138 

139 

@libentry()
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    num_groups,
    group_size,
    N,
    C,
    HW,
    BLOCK_N: tl.constexpr,
    BLOCK_HW: tl.constexpr,
):
    """Backward group-norm w.r.t. weight and bias: one program per channel.

    Grid is (C, 1, 1); pid is the channel index. Reduces over the batch and
    spatial dimensions:
        dW[c] = sum_n,hw x_hat * dY        dB[c] = sum_n,hw dY
    dW / dB pointers are optional; a None pointer skips that output.
    """
    pid = tle.program_id(0)
    # The group this channel belongs to, for indexing Mean/Rstd of shape
    # (N, num_groups).
    group = pid // group_size
    n_offset = tl.arange(0, BLOCK_N)
    hw_offset = tl.arange(0, BLOCK_HW)
    # Fix: combine the 2D bounds mask with elementwise `&`, not Python `and`.
    # `and` short-circuits on tensor truthiness and is not a reliable
    # elementwise operation; this matches group_norm_backward_kernel.
    xy_mask = (n_offset[:, None] < N) & (hw_offset[None, :] < HW)
    mr_mask = n_offset < N

    mean_ptr = Mean + group + n_offset * num_groups
    rstd_ptr = Rstd + group + n_offset * num_groups

    # Channel pid of every sample: base pid * HW, sample stride C * HW.
    dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
    x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]

    grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    x = tl.load(x_ptr, mask=xy_mask, other=0.0)
    x_f32 = x.to(tl.float32)
    mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]

    if dW is not None:
        # (x - mean) * rstd is x_hat recomputed from the saved statistics.
        dw = tl.sum((x_f32 - mean) * rstd * grad_y)
        tl.store(dW + pid, dw)
    if dB is not None:
        db = tl.sum(grad_y)
        tl.store(dB + pid, db)

182 

183 

def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
    """Group normalization forward entry point.

    Launches one kernel program per (sample, group) pair and returns
    (output, mean, rstd), where mean and rstd have shape (N, group).
    weight and bias are optional per-channel affine parameters.
    """
    logger.debug("GEMS GROUPNORM FORWARD")

    channels_per_group = triton.cdiv(C, group)
    input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    out = torch.empty_like(input)
    stat_shape = (N, group)
    mean = torch.empty(stat_shape, dtype=input.dtype, device=input.device)
    rstd = torch.empty(stat_shape, dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        group_norm_kernel[(N * group,)](
            input,
            out,
            weight,
            bias,
            mean,
            rstd,
            channels_per_group,
            C,
            HxW,
            group,
            eps,
            BLOCK_GROUP_SIZE=triton.next_power_of_2(channels_per_group),
            BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
        )
    return out, mean, rstd

214 

215 

def group_norm_backward(
    grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask
):
    """Group normalization backward entry point.

    output_mask is a 3-tuple of bools selecting which of
    (grad_input, grad_weight, grad_bias) to compute; unselected outputs
    are returned as None.
    """
    logger.debug("GEMS GROUPNORM BACKWARD")

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    group_size = triton.cdiv(C, group)

    if output_mask[0]:
        grad_inp = torch.empty_like(input)
        grid = (N * group,)
        with torch_device_fn.device(input.device):
            group_norm_backward_kernel[grid](
                grad_out,
                input,
                weight,
                mean,
                rstd,
                group,
                group_size,
                grad_inp,
                C,
                HxW,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
            )
    else:
        grad_inp = None

    # Robustness: truthiness test instead of fragile `is False` identity
    # comparisons (output_mask entries may be any truthy/falsy value).
    if not (output_mask[1] or output_mask[2]):
        return grad_inp, None, None

    # Bug fix: the original allocated BOTH grads with torch.empty_like(weight),
    # which raises when weight is None but a bias gradient is requested (the
    # kernel itself already handles dW/dB being None). Allocate an explicit
    # (C,) buffer instead, keeping weight's dtype/device when available and
    # falling back to grad_out's otherwise.
    ref = weight if weight is not None else grad_out
    weight_grad = (
        torch.empty(C, dtype=ref.dtype, device=ref.device) if output_mask[1] else None
    )
    bias_grad = (
        torch.empty(C, dtype=ref.dtype, device=ref.device) if output_mask[2] else None
    )
    with torch_device_fn.device(input.device):
        # One program per channel reduces dW/dB over batch and spatial dims.
        weight_bias_backward_kernel[(C, 1, 1)](
            grad_out,
            input,
            mean,
            rstd,
            weight_grad,
            bias_grad,
            group,
            group_size,
            N,
            C,
            HxW,
            BLOCK_N=triton.next_power_of_2(N),
            BLOCK_HW=triton.next_power_of_2(HxW),
        )
    return grad_inp, weight_grad, bias_grad