Coverage for src/flag_gems/runtime/backend/_ascend/ops/groupnorm.py: 0%
168 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, tl_extra_shim
9from flag_gems.utils import triton_lang_extension as tle
# Module logger named after this op file (e.g. "flag_gems.runtime._ascend.ops.groupnorm").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# Device-specific reciprocal-square-root primitive provided by the runtime shim.
rsqrt = tl_extra_shim.rsqrt
@libentry()
@triton.jit
def group_norm_backward_kernel(
    grad_y,
    X,
    W,
    Mean,
    Rstd,
    num_groups,
    group_size,
    grad_x,
    C,
    HW,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr = 128,
):
    """Compute the input gradient (grad_x) of GroupNorm.

    One program instance handles one (batch, group) pair; the launch grid is
    expected to be (N * num_groups,) so that ``pid`` indexes straight into the
    per-(batch, group) ``Mean``/``Rstd`` buffers.

    Args:
        grad_y: pointer to the upstream gradient (contiguous layout,
            group-major within a program: ``pid * group_size * HW`` base).
        X: pointer to the forward input, same layout as ``grad_y``.
        W: pointer to the per-channel weight (gamma), or None if unweighted.
        Mean, Rstd: per-(batch, group) statistics saved by the forward pass,
            one scalar each per program id.
        num_groups: number of normalization groups.
        group_size: channels per group.
        grad_x: output pointer for the input gradient, same layout as ``X``.
        C: total channel count (used to mask the last, possibly partial group).
        HW: flattened spatial size (H * W).
        BLOCK_GROUP_SIZE: compile-time tile over the channel axis
            (>= group_size).
        BLOCK_HW_SIZE: compile-time tile over the spatial axis.
    """
    pid = tle.program_id(0)
    group = pid % num_groups
    # Number of elements reduced per group: denominator of the mean terms below.
    num_elements = group_size * HW
    # Channel offsets of this group's channels within [0, C).
    group_offset = tl.arange(0, BLOCK_GROUP_SIZE)
    wb_offset = group * group_size + group_offset
    wb_mask = wb_offset < C
    rstd = tl.load(Rstd + pid).to(tl.float32)
    mean = tl.load(Mean + pid).to(tl.float32)
    if W is None:
        # No affine weight: dx_hat == dY (scalar 1 broadcasts below).
        weight = 1
    else:
        weight = tl.load(W + wb_offset, mask=wb_mask, other=0.0).to(tl.float32)[:, None]
    # Pass 1: accumulate sum(dx_hat) and sum(dx_hat * x_hat) over the group.
    dx_part2 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_GROUP_SIZE, BLOCK_HW_SIZE], dtype=tl.float32)
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]
        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        # Normalized activation; masked lanes are zeroed so they do not
        # contribute to the reductions.
        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat
    dx_2 = tl.sum(dx_part2)
    dx_3 = tl.sum(dx_part3)
    # Pass 2: dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat*x_hat)) / M)
    # — the standard normalization backward with the group-wide sums computed above.
    for off in range(0, HW, BLOCK_HW_SIZE):
        hw_offset = off + tl.arange(0, BLOCK_HW_SIZE)
        hw_mask = hw_offset < HW
        xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]
        xy_mask = wb_mask[:, None] & hw_mask[None, :]
        dY_val = tl.load(grad_y + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        X_val = tl.load(X + xy_offset, mask=xy_mask, other=0.0).to(tl.float32)
        x_hat = tl.where(xy_mask, rstd * (X_val - mean), 0.0)
        dx_hat = weight * dY_val
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / num_elements)
        tl.store(grad_x + xy_offset, dx, xy_mask)
@libentry()
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    num_groups,
    group_size,
    N,
    C,
    HW,
    BLOCK_N: tl.constexpr,
    BLOCK_HW: tl.constexpr,
):
    """Compute the weight (gamma) and bias (beta) gradients of GroupNorm.

    One program instance handles one channel (``pid`` in [0, C)); it reduces
    over all N batch elements and all HW spatial positions of that channel.

    Args:
        dY: pointer to the upstream gradient, contiguous (N, C, HW) layout.
        X: pointer to the forward input, same layout.
        Mean, Rstd: forward statistics, laid out as (N, num_groups).
        dW: output pointer for the weight gradient (C,), or None to skip.
        dB: output pointer for the bias gradient (C,), or None to skip.
        num_groups: number of normalization groups.
        group_size: channels per group (maps channel pid -> its group).
        N: batch size.
        C: channel count.
        HW: flattened spatial size.
        BLOCK_N: compile-time tile over the batch axis (>= N).
        BLOCK_HW: compile-time spatial extent to sweep (>= HW).
    """
    pid = tle.program_id(0)
    group = pid // group_size
    n_offset = tl.arange(0, BLOCK_N)
    mr_mask = n_offset < N
    # Pick this channel's group statistic for every batch element.
    mean_ptr = Mean + group + n_offset * num_groups
    rstd_ptr = Rstd + group + n_offset * num_groups
    mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]
    SUB_BLOCK_HW: tl.constexpr = 64
    dw_sum = 0.0
    db_sum = 0.0
    for hw_off in range(0, BLOCK_HW, SUB_BLOCK_HW):
        hw_offset = hw_off + tl.arange(0, SUB_BLOCK_HW)
        # Fix: combine the two mask tensors with the elementwise `&` operator.
        # The Python `and` keyword is not an elementwise op on Triton tensors
        # (it requires a single truth value), and the sibling
        # group_norm_backward_kernel already uses `&` for the same pattern.
        xy_mask = (n_offset[:, None] < N) & (hw_offset[None, :] < HW)
        dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
        x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]
        grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)
        x = tl.load(x_ptr, mask=xy_mask, other=0.0)
        x_f32 = x.to(tl.float32)
        if dW is not None:
            # dW[c] = sum over (n, hw) of x_hat * dY.
            dw_sum = dw_sum + tl.sum((x_f32 - mean) * rstd * grad_y)
        if dB is not None:
            # dB[c] = sum over (n, hw) of dY.
            db_sum = db_sum + tl.sum(grad_y)
    if dW is not None:
        dw = dw_sum
        tl.store(dW + pid, dw.to(mean.dtype))
    if dB is not None:
        db = db_sum
        tl.store(dB + pid, db.to(mean.dtype))
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SUB_HW_SIZE": 32}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 64}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 128}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 256}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 512}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 1024}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 2048}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 4096}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 8192}),
        triton.Config({"BLOCK_SUB_HW_SIZE": 16384}),
    ],
    key=["HW", "group_size"],
)
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
    X,
    Y,
    W,
    B,
    Mean,
    Rstd,
    group_size,
    C,
    HW,
    num_groups,
    eps,
    BLOCK_GROUP_SIZE: tl.constexpr,
    BLOCK_HW_SIZE: tl.constexpr,
    BLOCK_SUB_HW_SIZE: tl.constexpr,
):
    """GroupNorm forward: normalize each (batch, group) slice of X into Y.

    One program instance handles one (batch, group) pair (grid is
    ``(N * num_groups,)``) and makes three sequential sweeps over the group's
    data: mean, variance, then normalize-and-store. Also writes the
    per-(batch, group) ``Mean`` and ``Rstd`` needed by the backward pass.

    Args:
        X: input pointer, contiguous (N, C, HW) layout.
        Y: output pointer, same layout as X.
        W, B: optional per-channel weight/bias pointers (None to skip).
        group_size: channels per group.
        C: total channel count.
        HW: flattened spatial size.
        num_groups: number of normalization groups.
        eps: numerical-stability term added to the variance.
        BLOCK_GROUP_SIZE: compile-time channel tile (>= group_size).
        BLOCK_HW_SIZE: unused in this kernel body; kept so the launch
            signature stays compatible with the host wrapper.
        BLOCK_SUB_HW_SIZE: autotuned spatial tile per inner load.
    """
    pid = tl.program_id(0)
    batch_idx = pid // num_groups
    group_idx = pid % num_groups
    # Starting position of the current group within the whole tensor.
    batch_offset = batch_idx * C * HW
    group_start_channel = group_idx * group_size
    num_elements = group_size * HW
    # Pass 1: compute the mean.
    X_sum = 0.0
    for hw_start in range(0, HW, BLOCK_SUB_HW_SIZE):
        hw_offsets = hw_start + tl.arange(0, BLOCK_SUB_HW_SIZE)
        hw_mask = hw_offsets < HW
        # Sweep the contiguous HW dimension first, then step over channels.
        for c_idx in range(BLOCK_GROUP_SIZE):
            if c_idx < group_size and (group_start_channel + c_idx) < C:
                channel_offset = group_start_channel + c_idx
                # Contiguous access along the HW dimension.
                base_offset = batch_offset + channel_offset * HW + hw_offsets
                X_vals = tl.load(X + base_offset, mask=hw_mask, other=0.0).to(
                    tl.float32
                )
                X_sum += tl.sum(X_vals)
    mean = X_sum / num_elements
    # Pass 2: compute the variance.
    X_var_sum = 0.0
    for hw_start in range(0, HW, BLOCK_SUB_HW_SIZE):
        hw_offsets = hw_start + tl.arange(0, BLOCK_SUB_HW_SIZE)
        hw_mask = hw_offsets < HW
        for c_idx in range(BLOCK_GROUP_SIZE):
            if c_idx < group_size and (group_start_channel + c_idx) < C:
                channel_offset = group_start_channel + c_idx
                base_offset = batch_offset + channel_offset * HW + hw_offsets
                # Masked lanes load `mean` so their centered value is exactly 0.
                X_vals = tl.load(X + base_offset, mask=hw_mask, other=mean).to(
                    tl.float32
                )
                x_centered = X_vals - mean
                X_var_sum += tl.sum(x_centered * x_centered)
    var = X_var_sum / num_elements
    rstd = rsqrt(var + eps)
    # Pass 3: normalize and write back.
    for hw_start in range(0, HW, BLOCK_SUB_HW_SIZE):
        hw_offsets = hw_start + tl.arange(0, BLOCK_SUB_HW_SIZE)
        hw_mask = hw_offsets < HW
        for c_idx in range(BLOCK_GROUP_SIZE):
            if c_idx < group_size and (group_start_channel + c_idx) < C:
                channel_offset = group_start_channel + c_idx
                base_offset = batch_offset + channel_offset * HW + hw_offsets
                # Load the data.
                X_vals = tl.load(X + base_offset, mask=hw_mask, other=0.0).to(
                    tl.float32
                )
                # Normalize and apply the affine transform.
                x_normalized = (X_vals - mean) * rstd
                if W is not None:
                    w_val = tl.load(W + channel_offset)
                    x_normalized = x_normalized * w_val
                if B is not None:
                    b_val = tl.load(B + channel_offset)
                    x_normalized = x_normalized + b_val
                # Store the result.
                tl.store(Y + base_offset, x_normalized, mask=hw_mask)
    # Store the mean and reciprocal standard deviation for the backward pass.
    mean_rstd_offset = batch_idx * num_groups + group_idx
    tl.store(Mean + mean_rstd_offset, mean)
    tl.store(Rstd + mean_rstd_offset, rstd)
def group_norm(input, weight, bias, N, C, HxW, group, eps=1e-05):
    """GroupNorm forward entry point for the Ascend backend.

    Args:
        input: input tensor, logically (N, C, HxW); made contiguous here.
        weight: optional per-channel weight (gamma), or None.
        bias: optional per-channel bias (beta), or None.
        N, C, HxW: batch size, channel count, flattened spatial size.
        group: number of normalization groups.
        eps: numerical-stability term added to the variance.

    Returns:
        (output, mean, rstd) — the normalized tensor plus the per-(batch,
        group) statistics saved for the backward pass.
    """
    logger.debug("ASCEND GEMS GROUPNORM FORWARD")
    channels_per_group = triton.cdiv(C, group)
    input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    out = torch.empty_like(input)
    saved_mean = torch.empty((N, group), dtype=input.dtype, device=input.device)
    saved_rstd = torch.empty((N, group), dtype=input.dtype, device=input.device)
    # One kernel program per (batch, group) pair.
    grid = (N * group,)
    with torch_device_fn.device(input.device):
        group_norm_kernel[grid](
            input,
            out,
            weight,
            bias,
            saved_mean,
            saved_rstd,
            channels_per_group,
            C,
            HxW,
            group,
            eps,
            BLOCK_GROUP_SIZE=triton.next_power_of_2(channels_per_group),
            BLOCK_HW_SIZE=triton.next_power_of_2(HxW),
        )
    return out, saved_mean, saved_rstd
def group_norm_backward(
    grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask
):
    """GroupNorm backward entry point for the Ascend backend.

    Args:
        grad_out: upstream gradient, same shape as ``input``.
        input: forward input tensor, logically (N, C, HxW).
        mean, rstd: statistics saved by the forward pass, shape (N, group).
        weight: optional per-channel weight (gamma), or None.
        N, C, HxW: batch size, channel count, flattened spatial size.
        group: number of normalization groups.
        output_mask: three booleans selecting which gradients to compute:
            (grad_input, grad_weight, grad_bias).

    Returns:
        (grad_input, grad_weight, grad_bias); entries not requested via
        ``output_mask`` are None.
    """
    logger.debug("ASCEND GEMS GROUPNORM BACKWARD")
    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    group_size = triton.cdiv(C, group)

    if output_mask[0]:
        grad_inp = torch.empty_like(input)
        # One kernel program per (batch, group) pair.
        grid = (N * group,)
        with torch_device_fn.device(input.device):
            group_norm_backward_kernel[grid](
                grad_out,
                input,
                weight,
                mean,
                rstd,
                group,
                group_size,
                grad_inp,
                C,
                HxW,
                BLOCK_GROUP_SIZE=triton.next_power_of_2(group_size),
            )
    else:
        grad_inp = None

    if not (output_mask[1] or output_mask[2]):
        return grad_inp, None, None

    def _alloc_param_grad():
        # Fix: do not assume `weight` exists when affine gradients are
        # requested — the original `torch.empty_like(weight)` raised
        # AttributeError for weight=None (e.g. a bias-only affine), and the
        # bias gradient has shape (C,) regardless of `weight`.
        if weight is not None:
            return torch.empty_like(weight)
        return torch.empty((C,), dtype=input.dtype, device=input.device)

    weight_grad = _alloc_param_grad() if output_mask[1] else None
    bias_grad = _alloc_param_grad() if output_mask[2] else None
    # One kernel program per channel.
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel[(C, 1, 1)](
            grad_out,
            input,
            mean,
            rstd,
            weight_grad,
            bias_grad,
            group,
            group_size,
            N,
            C,
            HxW,
            BLOCK_N=triton.next_power_of_2(N),
            BLOCK_HW=triton.next_power_of_2(HxW),
        )
    return grad_inp, weight_grad, bias_grad