Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/groupnorm.py: 0%
187 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, tl_extra_shim
10from flag_gems.utils import triton_lang_extension as tle
# Child logger under the shared "flag_gems" root; lstrip(".") removes any
# leading dots from __name__ (defensive for relative-looking module names).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Backend-provided elementwise reciprocal square root used by the kernels below.
rsqrt = tl_extra_shim.rsqrt
@libentry()
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
    X,  # input, contiguous, addressed as a flat (N * C * HW) buffer
    Y,  # output, same layout as X
    W,  # per-channel weight (gamma), or None
    B,  # per-channel bias (beta), or None
    Mean,  # out: per-(batch, group) mean, one scalar per program
    Rstd,  # out: per-(batch, group) reciprocal std, one scalar per program
    group_size,  # channels per group = cdiv(C, num_groups)
    C,
    HW,  # spatial extent (H * W) per channel
    num_groups,
    eps,
    BLOCK_GROUP_SIZE: tl.constexpr,  # next_power_of_2(group_size)
    BLOCK_HW_SIZE: tl.constexpr,  # next_power_of_2(HW); whole slice in one tile
):
    # GroupNorm forward: each program normalizes one (batch, group) slice.
    # Launched with grid = (N * num_groups,) by group_norm() below.
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    hw_offset = tl.arange(0, BLOCK_HW_SIZE)
    # Channel indices covered by this group, for weight/bias loads.
    wb_offset = group * group_size + group_offset
    wb_mask = wb_offset < C
    # 2D element offsets of this group's slice in the flat input buffer.
    xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
    xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW
    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid
    X_ptr = X + xy_offset
    Y_ptr = Y + xy_offset
    # Masked lanes load 0.0 so they contribute nothing to the sums below.
    X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    # NOTE(review): the divisor is group_size * HW even when the last group is
    # partially masked (C not divisible by num_groups) — confirm callers only
    # use divisible channel counts.
    mean = tl.sum(X_val) / num_elements
    # Re-zero masked positions so the variance sum stays exact.
    x = tl.where(xy_mask, X_val - mean, 0.0)
    var = tl.sum(x * x) / num_elements
    rstd = rsqrt(var + eps)
    x_hat = x * rstd
    # Affine transform; scalar fallbacks keep the kernel valid without W/B.
    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0)[:, None]
    if B is None:
        bias = 0
    else:
        bias = tl.load(B + wb_offset, mask=wb_mask, other=0.0)[:, None]
    Y_val = x_hat * weight + bias
    tl.store(Y_ptr, Y_val, mask=xy_mask)
    # Per-program statistics are saved for the backward pass.
    tl.store(Mean_ptr, mean)
    tl.store(Rstd_ptr, rstd)
@libentry()
@triton.jit
def group_norm_backward_kernel(
    grad_y,  # upstream gradient, same layout as X
    X,  # forward input
    W,  # per-channel weight (gamma), or None
    Mean,  # per-(batch, group) mean saved by the forward kernel
    Rstd,  # per-(batch, group) reciprocal std saved by the forward kernel
    num_groups,
    group_size,  # channels per group
    grad_x,  # out: gradient w.r.t. X
    C,
    HW: tl.constexpr,  # spatial extent per channel
    BLOCK_GROUP_SIZE: tl.constexpr,  # next_power_of_2(group_size)
    BLOCK_HW_SIZE: tl.constexpr,  # HW tile width for the two loops below
):
    # Input-gradient of GroupNorm; one program per (batch, group) slice.
    # Two passes over HW: first accumulate the group-wide reduction terms,
    # then recompute x_hat and emit dx.
    pid = tle.program_id(0)
    group = pid % num_groups
    num_elements = group_size * HW
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    wb_offset = group * group_size + group_offset
    wb_mask = wb_offset < C
    rstd = tl.load(Rstd + pid).to(tl.float32)
    mean = tl.load(Mean + pid).to(tl.float32)
    if W is None:
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0).to(tl.float32)[:, None]
    # Pass 1: accumulate sum(dx_hat) and sum(dx_hat * x_hat) over the group.
    dx_part2 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]
        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        # Masked positions are zeroed so they drop out of the reductions.
        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat
    dx_2 = tl.sum(dx_part2)
    dx_3 = tl.sum(dx_part3)
    # Pass 2: dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat*x_hat)) / M)
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]
        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / num_elements)
        # Masked-out lanes get offset -1; the store mask suppresses the write —
        # presumably a workaround for this backend's address handling (verify).
        grad_x_offset = tl.where(xy_mask, xy_offset, -1)
        tl.store(grad_x + grad_x_offset, dx, xy_mask)
@libentry()
@triton.jit
def weight_bias_backward_kernel(
    dY,  # upstream gradient, flat (N * C * HW) layout
    X,  # forward input
    Mean,  # per-(batch, group) mean, shape (N, num_groups)
    Rstd,  # per-(batch, group) reciprocal std, shape (N, num_groups)
    dW,  # out: per-channel weight gradient, or None
    dB,  # out: per-channel bias gradient, or None
    num_groups,
    group_size,  # channels per group
    N,  # batch size
    C,
    HW,  # spatial extent per channel
    BLOCK_N: tl.constexpr,  # next_power_of_2(N); whole batch in one tile
    BLOCK_HW: tl.constexpr,  # next_power_of_2(HW); whole spatial dim in one tile
):
    # Weight/bias gradients; one program per channel (grid = (C, 1, 1)).
    # Reduces over the batch and spatial dims in a single (BLOCK_N, BLOCK_HW) tile.
    pid = tle.program_id(0)
    # Group that this channel belongs to, for indexing Mean/Rstd.
    group = pid // group_size
    n_offset = tl.arange(0, BLOCK_N)
    hw_offset = tl.arange(0, BLOCK_HW)
    xy_mask = n_offset[:, None] < N and hw_offset[None, :] < HW
    mr_mask = n_offset < N
    # Mean/Rstd are laid out (N, num_groups); pick this channel's group column.
    mean_ptr = Mean + group + n_offset * num_groups
    rstd_ptr = Rstd + group + n_offset * num_groups
    # Element (n, hw) of channel `pid` lives at n*C*HW + pid*HW + hw.
    dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
    x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
    grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
    x = tl.load(x_ptr, mask=xy_mask, other=0.0)
    x_f32 = x.to(tl.float32)
    mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    # dW = sum over (n, hw) of x_hat * dY; dB = sum of dY.
    if dW is not None:
        dw = tl.sum((x_f32 - mean) * rstd * grad_y)
        tl.store(dW + pid, dw.to(x.dtype))
    if dB is not None:
        db = tl.sum(grad_y)
        tl.store(dB + pid, db.to(x.dtype))
@libentry()
@triton.jit
def weight_bias_backward_kernel_loop(
    dY,  # upstream gradient, flat (N * C * HW) layout
    X,  # forward input
    Mean,  # per-(batch, group) mean, shape (N, num_groups)
    Rstd,  # per-(batch, group) reciprocal std, shape (N, num_groups)
    dW,  # out: per-channel weight gradient
    dB,  # out: per-channel bias gradient
    num_groups,
    group_size,  # channels per group
    N,  # batch size
    C,
    HW,  # spatial extent per channel
    BLOCK_N: tl.constexpr,  # batch tile height
    BLOCK_HW: tl.constexpr,  # spatial tile width
):
    # Tiled variant of weight_bias_backward_kernel: loops over N and HW instead
    # of loading everything in one tile. One program per channel.
    # NOTE(review): unlike the non-loop kernel, dW and dB are stored
    # unconditionally — this path assumes both gradients are requested; confirm
    # against the caller's shape-specific dispatch.
    pid = tle.program_id(0)
    group = pid // group_size
    # Running accumulators for sum(dY) and sum(x_hat * dY) across all tiles.
    grad_y_tile = tl.zeros((BLOCK_N, BLOCK_HW), dtype=tl.float32)  # grad_y_tile
    dw_tile = tl.zeros((BLOCK_N, BLOCK_HW), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # Per-sample statistics for this channel's group.
        mean_ptr = Mean + group + n_offset * num_groups
        rstd_ptr = Rstd + group + n_offset * num_groups
        mr_mask = n_offset < N
        mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
        rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
        for start_hw in range(0, HW, BLOCK_HW):
            hw_offset = start_hw + tl.arange(0, BLOCK_HW)
            xy_mask = n_offset[:, None] < N and hw_offset[None, :] < HW
            dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
            grad_y_tile += grad_y
            x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
            x = tl.load(x_ptr, mask=xy_mask, other=0.0)
            x_f32 = x.to(tl.float32)
            dw_tile += (x_f32 - mean) * rstd * grad_y
    # Final scalar reductions over the accumulated tiles.
    dw = tl.sum(dw_tile)
    db = tl.sum(grad_y_tile)
    tl.store(dW + pid, dw)
    tl.store(dB + pid, db)
def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
    """GroupNorm forward for the kunlunxin backend.

    Args:
        input: tensor viewed as (N, C, HxW); made contiguous here.
        weight, bias: optional per-channel affine parameters.
        N, C, HxW, group: shape parameters (group = number of groups).
        eps: numerical-stability term added to the variance.

    Returns:
        (y, mean, rstd) where mean/rstd have shape (N, group) and are
        consumed by group_norm_backward.
    """
    logger.debug("GEMS GROUPNORM FORWARD")
    group_size = triton.cdiv(C, group)
    input = input.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    y = torch.empty_like(input)
    mean = torch.empty((N, group), dtype=input.dtype, device=input.device)
    rstd = torch.empty((N, group), dtype=input.dtype, device=input.device)
    # One program per (batch, group) pair.
    grid = (N * group,)
    with torch_device_fn.device(input.device):
        # Backend simulation flags needed only for this specific shape.
        if N == 1 and C == 64 and HxW == 1024 and group == 64:
            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
        try:
            group_norm_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                group_size,
                C,
                HxW,
                group,
                eps,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
                BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
            )
        finally:
            # Always restore the environment, even if the launch raises;
            # previously a failed launch leaked the simulation flags.
            os.environ.pop("TRITONXPU_OTHER_SIM", None)
            os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
    return y, mean, rstd
def group_norm_backward(
    grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask
):
    """GroupNorm backward for the kunlunxin backend.

    Args:
        grad_out: upstream gradient, same shape as input.
        input: forward input; mean/rstd: statistics saved by group_norm.
        weight: optional per-channel weight used in the forward pass.
        N, C, HxW, group: shape parameters.
        output_mask: 3 booleans selecting (grad_input, grad_weight, grad_bias).

    Returns:
        (grad_inp, weight_grad, bias_grad); unrequested entries are None.
    """
    logger.debug("GEMS GROUPNORM BACKWARD")
    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    group_size = triton.cdiv(C, group)
    if output_mask[0]:
        grad_inp = torch.empty_like(input)
        grid = (N * group,)
        with torch_device_fn.device(input.device):
            # `os` is already imported at module level; the previous local
            # `import os` was redundant. The simulation flags are now cleaned
            # up in a finally block so a failed launch cannot leak them.
            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            try:
                group_norm_backward_kernel[grid](
                    grad_out,
                    input,
                    weight,
                    mean,
                    rstd,
                    group,
                    group_size,
                    grad_inp,
                    C,
                    HxW,
                    BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
                    BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
                    isCloseUnrollControl=True,
                )
            finally:
                os.environ.pop("TRITONXPU_OTHER_SIM", None)
                os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)
    else:
        grad_inp = None
    if output_mask[1] is False and output_mask[2] is False:
        return grad_inp, None, None
    # NOTE(review): bias_grad is allocated from `weight`, which assumes weight
    # is not None whenever grad_bias is requested — confirm with callers.
    weight_grad = torch.empty_like(weight) if output_mask[1] else None
    bias_grad = torch.empty_like(weight) if output_mask[2] else None
    with torch_device_fn.device(input.device):
        if N == 32 and C == 32 and HxW == 1024 and group == 8:
            # Shape-specific tiled variant (one program per channel).
            weight_bias_backward_kernel_loop[(C, 1, 1)](
                grad_out,
                input,
                mean,
                rstd,
                weight_grad,
                bias_grad,
                group,
                group_size,
                N,
                C,
                HxW,
                BLOCK_N=1,
                BLOCK_HW=triton.next_power_of_2(HxW),
                isCloseUnrollControl=True,
                isCloseCoreTiling=True,
            )
        else:
            # Bug fix: isCloseUnrollControl was previously assigned only when
            # BOTH grads were requested, raising UnboundLocalError when exactly
            # one of output_mask[1]/output_mask[2] was True.
            isCloseUnrollControl = bool(output_mask[1]) and bool(output_mask[2])
            weight_bias_backward_kernel[(C, 1, 1)](
                grad_out,
                input,
                mean,
                rstd,
                weight_grad,
                bias_grad,
                group,
                group_size,
                N,
                C,
                HxW,
                BLOCK_N=triton.next_power_of_2(N),
                BLOCK_HW=triton.next_power_of_2(HxW),
                isCloseUnrollControl=isCloseUnrollControl,
            )
    return grad_inp, weight_grad, bias_grad