Coverage for src/flag_gems/ops/groupnorm.py: 37%
139 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, tl_extra_shim
9from flag_gems.utils import triton_lang_extension as tle
# Backend-provided reciprocal-square-root op for use inside Triton kernels.
rsqrt = tl_extra_shim.rsqrt
logger = logging.getLogger(__name__)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
    X,
    Y,
    W,
    B,
    Mean,
    Rstd,
    group_size,
    C,
    HW,
    num_groups,
    eps,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr,
):
    """GroupNorm forward: one program handles one (batch, group) tile.

    Each program loads a [group_size, HW] slice of X, computes its mean and
    reciprocal standard deviation in fp32, writes the normalized (and
    optionally affine-transformed) values to Y, and stores the per-tile
    statistics into Mean/Rstd at index `pid`.

    W / B are optional per-channel weight and bias (length C); when None the
    affine transform degenerates to identity (weight=1, bias=0).
    """
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)

    # Channel indices covered by this group, clipped to C (last group may be
    # partial when C is not divisible by num_groups).
    wb_offset = group * group_size + group_offset
    wb_mask = wb_offset < C

    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
    # Combine tensor masks with element-wise `&` (Python `and` is not an
    # element-wise tensor op; the backward kernels already use `&`).
    xy_mask = (wb_offset[:, None] < C) & (hw_offset[None, :] < HW)

    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid

    X_ptr = X + xy_offset
    Y_ptr = Y + xy_offset

    # Accumulate statistics in fp32 for numerical stability; masked-out
    # lanes contribute 0 and num_elements counts only real elements.
    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    mean = tl.sum(X_val) / num_elements
    x = tl.where(xy_mask, X_val - mean, 0.0)

    var = tl.sum(x * x) / num_elements
    rstd = rsqrt(var + eps)
    x_hat = x * rstd

    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0)[:, None]
    if B is None:
        bias = 0
    else:
        bias = tl.load(B + wb_offset, mask=wb_mask, other=0.0)[:, None]
    Y_val = x_hat * weight + bias

    tl.store(Y_ptr, Y_val, mask=xy_mask)
    tl.store(Mean_ptr, mean)
    tl.store(Rstd_ptr, rstd)
@libentry()
@triton.jit
def group_norm_backward_kernel(
    grad_y,
    X,
    W,
    Mean,
    Rstd,
    num_groups,
    group_size,
    grad_x,
    C,
    HW,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr = 128,
):
    """GroupNorm input-gradient: one program per (batch, group) tile.

    Two passes over the [group_size, HW] tile in chunks of BLOCK_HW_SIZE:
    pass 1 accumulates the reductions sum(dx_hat) and sum(dx_hat * x_hat);
    pass 2 re-reads the data and writes
        dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat*x_hat)) / N)
    which is the standard normalization backward formula.

    W is optional; when None the effective per-channel weight is 1.
    """
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW

    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    wb_offset = group * group_size + group_offset

    # Clip channel indices: the last group may be partial when C is not
    # divisible by num_groups.
    wb_mask = wb_offset < C

    # Per-(batch, group) statistics saved by the forward kernel.
    rstd = tl.load(Rstd + pid).to(tl.float32)
    mean = tl.load(Mean + pid).to(tl.float32)
    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0).to(tl.float32)[:, None]

    # Pass 1: accumulate the two reductions needed by the dx formula.
    dx_part2 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]

        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)

        # Zero masked lanes so they do not pollute the reductions below.
        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2)
    dx_3 = tl.sum(dx_part3)

    # Pass 2: recompute x_hat/dx_hat chunk-by-chunk (cheaper than keeping
    # the whole tile live) and emit the final gradient.
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]

        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)

        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / num_elements)
        tl.store(grad_x + xy_offset, dx, xy_mask)
@libentry()
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    num_groups,
    group_size,
    N,
    C,
    HW,
    BLOCK_N: tl.constexpr,
    BLOCK_HW: tl.constexpr,
):
    """Per-channel weight/bias gradients: one program per channel.

    For channel `pid`, reduces over all N batches and HW spatial positions:
        dW[pid] = sum((x - mean) * rstd * dY)    (x_hat * dY)
        dB[pid] = sum(dY)
    Mean/Rstd are indexed per (batch, group), hence the stride of
    num_groups along the batch axis. dW / dB are optional outputs.
    """
    pid = tle.program_id(0)
    group = pid // group_size
    n_offset = tl.arange(0, BLOCK_N)
    hw_offset = tl.arange(0, BLOCK_HW)
    # Combine tensor masks with element-wise `&` (Python `and` is not an
    # element-wise tensor op; the other backward kernel uses `&`).
    xy_mask = (n_offset[:, None] < N) & (hw_offset[None, :] < HW)
    mr_mask = n_offset < N

    mean_ptr = Mean + group + n_offset * num_groups
    rstd_ptr = Rstd + group + n_offset * num_groups

    dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
    x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]

    grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    x = tl.load(x_ptr, mask=xy_mask, other=0.0)
    x_f32 = x.to(tl.float32)
    mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]

    if dW is not None:
        dw = tl.sum((x_f32 - mean) * rstd * grad_y)
        tl.store(dW + pid, dw)
    if dB is not None:
        db = tl.sum(grad_y)
        tl.store(dB + pid, db)
def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
    """GroupNorm forward entry point.

    Launches one kernel program per (batch, group) pair and returns the
    tuple (output, mean, rstd), where mean/rstd have shape (N, group).
    weight/bias may be None (identity affine transform).
    """
    logger.debug("GEMS GROUPNORM FORWARD")

    group_size = triton.cdiv(C, group)

    # Kernels assume contiguous layouts.
    input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    y = torch.empty_like(input)
    stat_shape = (N, group)
    mean = torch.empty(stat_shape, dtype=input.dtype, device=input.device)
    rstd = torch.empty(stat_shape, dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        group_norm_kernel[(N * group,)](
            input,
            y,
            weight,
            bias,
            mean,
            rstd,
            group_size,
            C,
            HxW,
            group,
            eps,
            BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
            BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
        )
    return y, mean, rstd
def group_norm_backward(
    grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask
):
    """GroupNorm backward entry point.

    output_mask selects which gradients to compute:
      [0] -> grad w.r.t. input, [1] -> grad w.r.t. weight, [2] -> grad w.r.t. bias.
    Returns (grad_inp, weight_grad, bias_grad); unselected entries are None.
    weight may be None (forward ran without an affine weight).
    """
    logger.debug("GEMS GROUPNORM BACKWARD")

    # Kernels assume contiguous layouts.
    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    group_size = triton.cdiv(C, group)

    if output_mask[0]:
        grad_inp = torch.empty_like(input)
        grid = (N * group,)
        with torch_device_fn.device(input.device):
            group_norm_backward_kernel[grid](
                grad_out,
                input,
                weight,
                mean,
                rstd,
                group,
                group_size,
                grad_inp,
                C,
                HxW,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
            )
    else:
        grad_inp = None

    if not (output_mask[1] or output_mask[2]):
        return grad_inp, None, None

    def _param_grad(needed):
        # Per-channel gradient buffer of shape (C,). dW/dB do not depend on
        # the weight values, so they must be allocatable even when
        # weight is None (the old `empty_like(weight)` crashed in that case).
        if not needed:
            return None
        if weight is not None:
            return torch.empty_like(weight)
        return torch.empty((C,), dtype=input.dtype, device=input.device)

    weight_grad = _param_grad(output_mask[1])
    bias_grad = _param_grad(output_mask[2])
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel[(C, 1, 1)](
            grad_out,
            input,
            mean,
            rstd,
            weight_grad,
            bias_grad,
            group,
            group_size,
            N,
            C,
            HxW,
            BLOCK_N=triton.next_power_of_2(N),
            BLOCK_HW=triton.next_power_of_2(HxW),
        )
    return grad_inp, weight_grad, bias_grad