Coverage for src/flag_gems/runtime/backend/_metax/ops/groupnorm.py: 0%
147 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, tl_extra_shim
9from flag_gems.utils import triton_lang_extension as tle
11rsqrt = tl_extra_shim.rsqrt
13logger = logging.getLogger("flag_gems." + __name__)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
    X,  # input, assumed contiguous with per-(batch, group) chunks of group_size * HW elements — TODO confirm layout against caller
    Y,  # output, same shape/layout as X
    W,  # optional per-channel weight (length C) or None
    B,  # optional per-channel bias (length C) or None
    Mean,  # output buffer, one mean per (batch, group), indexed by pid
    Rstd,  # output buffer, one reciprocal std per (batch, group), indexed by pid
    group_size,  # channels per group, C // num_groups
    C,
    HW,
    num_groups,
    eps,
    BLOCK_GROUP_SIZE: tl.constexpr,  # next_power_of_2(group_size)
    BLOCK_HW_SIZE: tl.constexpr,  # next_power_of_2(HW)
):
    # One program per (batch, group) pair; launch grid is (N * num_groups,).
    pid = tle.program_id(0)
    group = pid % num_groups
    # True element count of one (batch, group) chunk — used both as the chunk
    # stride and as the normalization divisor.
    num_elements = group_size * HW
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)

    # Channel indices covered by this group, for loading W/B.
    wb_offset = group * group_size + group_offset
    wb_mask = wb_offset < C

    # 2D (group_size x HW) tile of flat offsets into X/Y for this chunk.
    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
    xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW

    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid

    X_ptr = X + xy_offset
    Y_ptr = Y + xy_offset

    # Compute mean/variance in fp32; masked lanes load 0.0 and num_elements
    # counts only real elements, so the reductions are exact.
    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    mean = tl.sum(X_val) / num_elements
    # Zero out padded lanes so they do not contribute to the variance sum.
    x = tl.where(xy_mask, X_val - mean, 0.0)

    var = tl.sum(x * x) / num_elements
    rstd = rsqrt(var + eps)
    x_hat = x * rstd

    # Affine transform: scalar identities when W/B are absent.
    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0)[:, None]
    if B is None:
        bias = 0
    else:
        bias = tl.load(B + wb_offset, mask=wb_mask, other=0.0)[:, None]
    Y_val = x_hat * weight + bias

    tl.store(Y_ptr, Y_val, mask=xy_mask)
    # Per-(batch, group) statistics, consumed by the backward kernels.
    tl.store(Mean_ptr, mean)
    tl.store(Rstd_ptr, rstd)
@libentry()
@triton.jit
def group_norm_backward_kernel(
    grad_y,  # upstream gradient dL/dY, same layout as X
    X,  # forward input (contiguous), per-(batch, group) chunks of group_size * HW
    W,  # optional per-channel weight (length C) or None
    Mean,  # per-(batch, group) mean from forward, indexed by pid
    Rstd,  # per-(batch, group) reciprocal std from forward, indexed by pid
    num_groups,
    group_size,  # channels per group, C // num_groups
    grad_x,  # output: dL/dX, same layout as X
    C,
    HW,
    BLOCK_GROUP_SIZE: tl.constexpr,  # next_power_of_2(group_size)
    BLOCK_HW_SIZE: tl.constexpr,  # next_power_of_2(HW)
):
    # One program per (batch, group) pair, mirroring group_norm_kernel's
    # launch grid of (N * num_groups,).
    pid = tle.program_id(0)
    group = pid % num_groups
    # BUGFIX: the size of one (batch, group) chunk in a contiguous (N, C, HW)
    # tensor is group_size * HW, not group_size * BLOCK_HW_SIZE. Using the
    # padded block size both mis-addressed X/grad_y/grad_x (via xy_offset)
    # and used the wrong divisor for grad_mean whenever HW was not a power
    # of two. This now matches the forward kernel exactly.
    num_elements = group_size * HW

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)
    # Channel indices covered by this group, for loading W.
    wb_offset = group * group_size + group_offset

    wb_mask = wb_offset < C

    # 2D (group_size x HW) tile of flat offsets into the chunk owned by pid.
    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
    xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW

    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid
    X_ptr = X + xy_offset
    dY_ptr = grad_y + xy_offset
    dX_ptr = grad_x + xy_offset

    # Statistics saved by the forward pass; promote to fp32 for accumulation.
    rstd = tl.load(Rstd_ptr).to(tl.float32)
    mean = tl.load(Mean_ptr).to(tl.float32)
    dY_val = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)

    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0).to(tl.float32)[:, None]

    # Gradient w.r.t. x_hat; padded lanes carry 0 from the masked loads.
    dx_hat = weight * dY_val

    # Centered input, zeroed on padded lanes so reductions stay exact.
    x = tl.where(xy_mask, X_val - mean, 0.0)

    # Standard group-norm backward: propagate through rstd (variance) and
    # mean, then combine.
    grad_std = tl.sum(dx_hat * x)
    grad_var = grad_std * -(0.5 * rstd * rstd * rstd) / (HW * group_size)
    grad_distance = 2 * x * grad_var
    grad_centered_mean = dx_hat * rstd + grad_distance
    grad_mean = -tl.sum(grad_centered_mean) / num_elements
    grad_X = grad_centered_mean + grad_mean
    tl.store(dX_ptr, grad_X, mask=xy_mask)
@libentry()
@triton.jit
def weight_bias_backward_kernel(
    dY,  # upstream gradient, contiguous (N, C, HW)
    X,  # forward input, contiguous (N, C, HW)
    Mean,  # per-(batch, group) mean, shape (N, num_groups)
    Rstd,  # per-(batch, group) reciprocal std, shape (N, num_groups)
    dW,  # output: per-channel weight grad (length C), or None to skip
    dB,  # output: per-channel bias grad (length C), or None to skip
    num_groups,
    group_size,  # channels per group, C // num_groups
    N,
    C,
    HW,
    BLOCK_N: tl.constexpr,  # next_power_of_2(N)
    BLOCK_HW: tl.constexpr,  # next_power_of_2(HW)
):
    # One program per channel: pid in [0, C); reduces over batch and HW.
    pid = tle.program_id(0)
    # Which group this channel belongs to, to pick the right Mean/Rstd column.
    group = pid // group_size
    n_offset = tl.arange(0, BLOCK_N)
    hw_offset = tl.arange(0, BLOCK_HW)
    xy_mask = n_offset[:, None] < N and hw_offset[None, :] < HW
    mr_mask = n_offset < N

    # Mean/Rstd are (N, num_groups): row stride num_groups, column = group.
    mean_ptr = Mean + group + n_offset * num_groups
    rstd_ptr = Rstd + group + n_offset * num_groups

    # BUGFIX: the per-channel stride in a contiguous (N, C, HW) tensor is HW,
    # not BLOCK_HW (= next_power_of_2(HW)). The old `pid * BLOCK_HW` offset
    # read the wrong memory for every channel whenever HW was not a power of
    # two.
    dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
    x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]

    grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    x = tl.load(x_ptr, mask=xy_mask, other=0.0)
    x_f32 = x.to(tl.float32)
    mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]

    if dW is not None:
        # dW[c] = sum over (n, hw) of x_hat * dY, with x_hat = (x - mean) * rstd.
        dw = tl.sum((x_f32 - mean) * rstd * grad_y, 1)
        dw = tl.sum(dw)
        tl.store(dW + pid, dw.to(x.dtype))
    if dB is not None:
        # dB[c] = sum over (n, hw) of dY.
        db = tl.sum(grad_y, 1)
        db = tl.sum(db)
        tl.store(dB + pid, db.to(x.dtype))
class GroupNorm(torch.autograd.Function):
    """Group normalization over a contiguous (N, C, HW) view of the input.

    forward returns (y, mean, rstd); mean/rstd have shape (N, num_groups)
    and are reused by backward. Launch grid is one program per
    (batch, group) pair for the normalization kernels and one per channel
    for the weight/bias gradient kernel.
    """

    @staticmethod
    def forward(ctx, x, N, C, HW, num_groups, weight=None, bias=None, eps=1e-05):
        logger.debug("METAX GEMS GROUPNORM FORWARD")
        group_size = C // num_groups
        # Kernels assume dense (N, C, HW) addressing.
        x = x.contiguous()
        if weight is not None:
            weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y = torch.empty_like(x)
        # NOTE(review): stats are stored in x.dtype (possibly fp16); the
        # kernels upcast to fp32 when reading them back.
        mean = torch.empty((N, num_groups), dtype=x.dtype, device=x.device)
        rstd = torch.empty((N, num_groups), dtype=x.dtype, device=x.device)
        grid = (N * num_groups,)

        with torch_device_fn.device(x.device):
            group_norm_kernel[grid](
                x,
                y,
                weight,
                bias,
                mean,
                rstd,
                group_size,
                C,
                HW,
                num_groups,
                eps,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
                BLOCK_HW_SIZE=triton.next_power_of_2(HW),
            )
        # BUGFIX: save context whenever ANY differentiable input needs a
        # gradient. Previously this was gated on x.requires_grad alone, so
        # training only weight/bias crashed in backward (empty saved_tensors,
        # missing ctx attributes).
        needs_grad = (
            x.requires_grad
            or (weight is not None and weight.requires_grad)
            or (bias is not None and bias.requires_grad)
        )
        if needs_grad:
            ctx.save_for_backward(x, weight, bias, mean, rstd)
            ctx.num_groups = num_groups
            ctx.group_size = group_size
            ctx.N = N
            ctx.C = C
            ctx.HW = HW
        return y, mean, rstd

    @staticmethod
    def backward(ctx, y_grad, mean_grad, rstd_grad):
        # mean_grad/rstd_grad are ignored: the saved statistics are treated
        # as intermediates, not independent differentiable outputs.
        logger.debug("METAX GEMS GROUPNORM BACKWARD")
        y_grad = y_grad.contiguous()
        (x, weight, bias, mean, rstd) = ctx.saved_tensors
        num_groups = ctx.num_groups
        group_size = ctx.group_size
        N = ctx.N
        C = ctx.C
        HW = ctx.HW
        x_grad = torch.empty_like(x)
        grid = (N * num_groups,)
        with torch_device_fn.device(x.device):
            group_norm_backward_kernel[grid](
                y_grad,
                x,
                weight,
                mean,
                rstd,
                num_groups,
                group_size,
                x_grad,
                C,
                HW,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
                BLOCK_HW_SIZE=triton.next_power_of_2(HW),
            )
        # No affine parameters: only the input gradient is needed.
        if weight is None and bias is None:
            return x_grad, None, None, None, None, None, None, None

        weight_grad = None if weight is None else torch.empty_like(weight)
        bias_grad = None if bias is None else torch.empty_like(bias)
        with torch_device_fn.device(x.device):
            # One program per channel.
            weight_bias_backward_kernel[(C, 1, 1)](
                y_grad,
                x,
                mean,
                rstd,
                weight_grad,
                bias_grad,
                num_groups,
                group_size,
                N,
                C,
                HW,
                BLOCK_N=triton.next_power_of_2(N),
                BLOCK_HW=triton.next_power_of_2(HW),
            )
        # Gradient slots mirror forward's (x, N, C, HW, num_groups, weight,
        # bias, eps) signature; non-tensor args get None.
        return x_grad, None, None, None, None, weight_grad, bias_grad, None
def group_norm(x, weight, bias, N, C, HW, num_groups, eps):
    """Functional entry point for group normalization.

    Delegates to the GroupNorm autograd Function and returns whatever it
    produces (output, mean, rstd). Note that apply() takes the tensor/shape
    arguments in a different order than this wrapper.
    """
    forward_args = (x, N, C, HW, num_groups, weight, bias, eps)
    return GroupNorm.apply(*forward_args)