Coverage for src/flag_gems/runtime/backend/_cambricon/ops/weightnorm.py: 0%
228 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import copy
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
13from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM
15logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
16MAX_N = 31744
# Forward weight-norm kernel for dim == v.ndim - 1 ("last"): v is viewed as a
# contiguous (M, N) matrix and the L2 norm reduces over the M (row) axis,
# producing one norm / gain per column:
#   norm[n]      = sqrt(sum_m v[m, n]^2 + eps)
#   output[m, n] = v[m, n] / norm[n] * g[n]
# Grid: one program per BLOCK_COL_SIZE stripe of columns; each program loops
# over all M rows in BLOCK_ROW_SIZE chunks, in two passes (sum-of-squares,
# then normalize-and-scale).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_last(
    output,  # out: normalized weights, same (M, N) layout as v
    norm,    # out: per-column norms, N elements
    v,       # in:  weight tensor, contiguous (M, N)
    g,       # in:  per-column gain, N elements
    M,
    N,
    eps,     # added under the sqrt / kept tiny (see host wrapper)
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Column indices owned by this program, kept as a [COL, 1] column vector
    # so they broadcast against the [1, ROW] row vector below.
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tl.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = bx + tx
    col_mask = col_offset < N

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
    # Pass 1: accumulate per-column sum of squares over every row chunk.
    v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    # Per-column L2 norm; eps keeps sqrt(0) off the all-zero-column path.
    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + col_offset, normalized[:, None], mask=col_mask)
    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)

    # Pass 2: re-read v and emit v / norm * g.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)
def config_prune_for_first(configs, named_args, **kwargs):
    """Early-prune autotune configs for ``weight_norm_kernel_first``.

    When ``N`` (the per-row reduction length) is below ``MAX_N``, each
    candidate config is rewritten instead of searched: ``BLOCK_COL_SIZE``
    is clamped to ``N`` and ``BLOCK_ROW_SIZE`` is derived from the NRAM
    budget — either one non-pipelined config (num_stages=1) when a core's
    share of rows fits at once, or a non-pipelined plus a pipelined
    (num_stages=3) variant when it does not. Configs that collapse to the
    same (row, col, warps, stages) key are de-duplicated, first one wins.

    Args:
        configs: candidate ``triton.Config`` objects from the tuned table.
        named_args: kernel launch args; only ``M`` and ``N`` are read.
        **kwargs: unused, required by triton's prune-hook signature.

    Returns:
        De-duplicated list of (possibly rewritten) configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    configs_map = {}
    # When N is less than MAX_N, no reduction loops
    for config in configs:
        kw = config.kwargs
        BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages = (
            kw["BLOCK_ROW_SIZE"],
            kw["BLOCK_COL_SIZE"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Rewrite a deep copy so the shared tuned-config table stays intact.
            config = copy.deepcopy(config)
            BLOCK_COL_SIZE = config.kwargs["BLOCK_COL_SIZE"] = N
            m_per_core = math.ceil(M / TOTAL_CORE_NUM)
            # NRAM budget heuristic: presumably 3 f32 row buffers of
            # BLOCK_COL_SIZE plus one scalar per row, 4 bytes each — TODO confirm.
            nram_usage = (3 * BLOCK_COL_SIZE + 1) * m_per_core * 4
            if nram_usage < MAX_NRAM_SIZE:
                # A core's whole share of rows fits in one tile.
                BLOCK_ROW_SIZE = config.kwargs["BLOCK_ROW_SIZE"] = m_per_core
                num_stages = config.num_stages = 1
                key = (BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages)
                configs_map.setdefault(key, config)
            else:
                # Largest row block that fits without software pipelining.
                max_block_m_without_pipe = (
                    MAX_NRAM_SIZE // 4 // (3 * BLOCK_COL_SIZE + 1)
                )
                BLOCK_ROW_SIZE = config.kwargs[
                    "BLOCK_ROW_SIZE"
                ] = max_block_m_without_pipe
                num_stages = config.num_stages = 1
                key = (BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages)
                configs_map.setdefault(key, config)

                # Additionally offer a pipelined (num_stages=3) variant.
                config = copy.deepcopy(config)
                # NOTE(review): this recomputed row block for the pipelined
                # budget is never written back to
                # config.kwargs["BLOCK_ROW_SIZE"], so the pipelined variant
                # reuses the non-pipelined row block. Looks unintended, but
                # preserved as-is — confirm with the kernel authors.
                max_block_m_without_pipe = (
                    MAX_NRAM_SIZE // 4 // (6 * BLOCK_COL_SIZE + 4)
                )
                num_stages = config.num_stages = 3
                key = (BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages)
                configs_map.setdefault(key, config)
        key = (BLOCK_ROW_SIZE, BLOCK_COL_SIZE, num_warps, num_stages)
        # Only keep one config for the same key (no-op for keys inserted
        # above; registers the untouched config when N >= MAX_N).
        configs_map.setdefault(key, config)
    # Idiomatic replacement for the manual key/value append loop.
    return list(configs_map.values())
def tile_mode_for_first(args):
    """Select the TILE_MODE heuristic value for ``weight_norm_kernel_first``.

    Returns:
        0 -- a single (BLOCK_ROW_SIZE x BLOCK_COL_SIZE) tile per core
             covers the whole (M, N) problem;
        1 -- one tile spans all N columns but rows must be looped;
        2 -- both rows and columns must be looped.
    """
    rows_fit = args["BLOCK_ROW_SIZE"] * TOTAL_CORE_NUM >= args["M"]
    cols_fit = args["BLOCK_COL_SIZE"] >= args["N"]
    if not cols_fit:
        return 2
    return 0 if rows_fit else 1
# Forward weight-norm kernel for dim == 0 ("first"): v is viewed as a
# contiguous (M, N) matrix and the L2 norm reduces over the N (column) axis,
# producing one norm / gain per row:
#   norm[m]      = sqrt(sum_n v[m, n]^2 + eps)
#   output[m, n] = v[m, n] / norm[m] * g[m]
# TILE_MODE (from tile_mode_for_first): 0 = whole problem fits in one tile
# per program; 1 = one tile spans all N columns, loop over rows; 2 = loop
# over both rows and columns (two passes over N).
# NOTE(review): modes 0 and 1 apply no column mask, so they rely on
# BLOCK_COL_SIZE == N (config_prune_for_first clamps it when N < MAX_N);
# with BLOCK_COL_SIZE > N the loads would bleed into the next row — confirm.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_first"),
    key=["M", "N"],
    prune_configs_by={"early_config_prune": config_prune_for_first},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: tile_mode_for_first(args),
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_first(
    output,  # out: normalized weights, same (M, N) layout as v
    norm,    # out: per-row norms, M elements
    v,       # in:  weight tensor, contiguous (M, N)
    g,       # in:  per-row gain, M elements
    M,
    N,
    eps,     # added under the sqrt / kept tiny (see host wrapper)
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    # Rows are statically split across programs: each owns split_m rows.
    split_m = tl.cdiv(M, pnum)
    m_start = pid_m * split_m
    if TILE_MODE == 0:
        # One tile covers this program's rows and all N columns.
        m_offset = pid_m * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)
        n_offset = tl.arange(0, BLOCK_COL_SIZE)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        v_value = tl.load(v + offset, mask=mask).to(tl.float32)
        normalized = tl.sqrt(tl.sum(v_value * v_value, axis=1) + eps)
        tl.store(norm + m_offset[:, None], normalized[:, None], mask=mask)
        g_value = tl.load(g + m_offset[:, None], mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + offset, out, mask=mask)
    elif TILE_MODE == 1:
        # All N columns fit in one tile; loop over this program's rows.
        for m_idx in range(0, split_m, BLOCK_ROW_SIZE):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_ROW_SIZE)
            n_offset = tl.arange(0, BLOCK_COL_SIZE)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            v_value = tl.load(v + offset, mask=mask).to(tl.float32)
            normalized = tl.sqrt(tl.sum(v_value * v_value, axis=1) + eps)
            tl.store(norm + m_offset[:, None], normalized[:, None], mask=mask)
            g_value = tl.load(g + m_offset[:, None], mask=mask).to(tl.float32)
            v_vec = v_value / normalized[:, None]
            out = v_vec * g_value
            tl.store(output + offset, out, mask=mask)
    else:
        # General case: loop over rows, and over columns in two passes
        # (sum-of-squares first, then normalize-and-scale).
        for m_idx in range(0, split_m, BLOCK_ROW_SIZE):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_ROW_SIZE)
            m_mask = m_offset[:, None] < M
            v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
            for start_n in range(0, N, BLOCK_COL_SIZE):
                n_offset = start_n + tl.arange(0, BLOCK_COL_SIZE)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_mask and n_offset[None, :] < N
                v_value = tl.load(v + offset, mask=mask).to(tl.float32)
                v_block += v_value * v_value

            normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
            tl.store(norm + m_offset[:, None], normalized[:, None], mask=m_mask)
            g_value = tl.load(g + m_offset[:, None], mask=m_mask).to(tl.float32)

            for start_n in range(0, N, BLOCK_COL_SIZE):
                n_offset = start_n + tl.arange(0, BLOCK_COL_SIZE)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_mask and n_offset[None, :] < N
                v_value = tl.load(v + offset, mask=mask).to(tl.float32)
                v_vec = v_value / normalized[:, None]
                out = v_vec * g_value
                tl.store(output + offset, out, mask=mask)
# Backward weight-norm kernel for dim == v.ndim - 1 ("last"). Given the
# upstream gradient w of the forward output, with the norm reduced over the
# M (row) axis, computes per column n:
#   vw_sum[n]     = sum_m v[m, n] * w[m, n]
#   v_grad[m, n]  = g[n] * (w[m, n] / norm[n] - v[m, n] * vw_sum[n] / norm[n]^3)
#   g_grad[n]     = vw_sum[n] / norm[n]
# (eps is added to each denominator to avoid division by zero).
# Grid: one program per BLOCK_COL_SIZE stripe of columns, looping over rows.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_last(
    v_grad,  # out: gradient w.r.t. v, (M, N)
    g_grad,  # out: gradient w.r.t. g, N elements
    w,       # in:  upstream gradient of the forward output, (M, N)
    v,       # in:  saved weight tensor, (M, N)
    g,       # in:  saved per-column gain, N elements
    norm,    # in:  saved per-column norms, N elements
    M,
    N,
    eps,     # guards the 1/norm divisions
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Column indices owned by this program, as a [COL, 1] column vector.
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tl.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = tx + bx
    col_mask = col_offset < N

    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)
    norm_value = tl.load(norm + col_offset, mask=col_mask).to(tl.float32)

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]

    # Pass 1: per-column dot product of v and the upstream gradient.
    vw_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        vw_block += v_value * w_value
    vw_sum = tl.sum(vw_block, 1)[:, None]

    # Pass 2: re-read v and w and emit the v gradient.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + col_offset, g_grad_value, mask=col_mask)
# Backward weight-norm kernel for dim == 0 ("first"). Mirror image of
# weight_norm_bwd_kernel_last: the norm was reduced over the N (column)
# axis, so per row m:
#   vw_sum[m]     = sum_n v[m, n] * w[m, n]
#   v_grad[m, n]  = g[m] * (w[m, n] / norm[m] - v[m, n] * vw_sum[m] / norm[m]^3)
#   g_grad[m]     = vw_sum[m] / norm[m]
# (eps is added to each denominator to avoid division by zero).
# Grid: one program per BLOCK_ROW_SIZE stripe of rows, looping over columns.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_first(
    v_grad,  # out: gradient w.r.t. v, (M, N)
    g_grad,  # out: gradient w.r.t. g, M elements
    w,       # in:  upstream gradient of the forward output, (M, N)
    v,       # in:  saved weight tensor, (M, N)
    g,       # in:  saved per-row gain, M elements
    norm,    # in:  saved per-row norms, M elements
    M,
    N,
    eps,     # guards the 1/norm divisions
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Row indices owned by this program, as a [ROW, 1] column vector.
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # Pass 1: per-row dot product of v and the upstream gradient.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * w_value
    vw_sum = tl.sum(v_block, 1)[:, None]

    # Pass 2: re-read v and w and emit the v gradient.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)
def weight_norm_interface(v, g, dim=0):
    """Fused weight-norm forward: ``output = v / norm(v) * g``.

    The L2 norm reduces over every axis of ``v`` except ``dim``. Only the
    two layouts the kernels support are accepted: ``dim == 0`` (reduce the
    trailing axes) and ``dim == v.ndim - 1`` (reduce the leading axes).

    Args:
        v: weight tensor (made contiguous internally).
        g: gain tensor matching the kept axis.
        dim: axis kept by the norm; must be 0 or ``v.ndim - 1``.

    Returns:
        (output, norm): normalized weights and the per-slice norms.

    Raises:
        ValueError: if ``dim`` is neither 0 nor ``v.ndim - 1``.
    """
    logger.debug("GEMS_CAMBRICON WEIGHTNORM FORWARD")
    v = v.contiguous()
    g = g.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g)
    # Smallest positive normal float32: avoids a zero denominator without
    # measurably biasing the norm.
    eps = torch.finfo(torch.float32).tiny
    if dim == 0:
        M = v.shape[0]
        N = math.prod(v.shape[1:])
        with torch_device_fn.device(v.device):
            # Fixed launch: one program per core (rows are split inside).
            weight_norm_kernel_first[TOTAL_CORE_NUM, 1, 1](
                output, norm, v, g, M, N, eps=eps
            )
    elif dim == v.ndim - 1:
        M = math.prod(v.shape[:-1])
        N = v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_last[grid](
                output, norm, v, g, M, N, eps=eps
            )
    else:
        # Previously this fell through and returned the uninitialized
        # ``empty_like`` tensors for unsupported dims — fail loudly instead.
        raise ValueError(
            f"weight_norm_interface supports dim == 0 or dim == v.ndim - 1, got {dim}"
        )
    return output, norm
def weight_norm_interface_backward(w_grad, saved_v, saved_g, saved_norms, dim):
    """Backward of :func:`weight_norm_interface`: gradients w.r.t. v and g.

    Dispatches to the backward kernel matching the layout used by the
    forward pass (``dim == 0`` or ``dim == saved_v.ndim - 1``).

    Args:
        w_grad: upstream gradient of the forward ``output``.
        saved_v: weight tensor saved by the forward pass.
        saved_g: gain tensor saved by the forward pass.
        saved_norms: norms saved by the forward pass.
        dim: the axis the forward pass kept; must be 0 or last.

    Returns:
        (v_grad, g_grad).

    Raises:
        ValueError: if ``dim`` is neither 0 nor ``saved_v.ndim - 1``.
    """
    logger.debug("GEMS_CAMBRICON WEIGHTNORM BACKWARD")
    w_grad = w_grad.contiguous()
    saved_v = saved_v.contiguous()
    saved_g = saved_g.contiguous()
    saved_norms = saved_norms.contiguous()
    v_grad = torch.empty_like(saved_v)
    g_grad = torch.empty_like(saved_g)
    # Smallest positive normal float32: guards the 1/norm divisions.
    eps = torch.finfo(torch.float32).tiny

    if dim == 0:
        M = saved_v.shape[0]
        N = math.prod(saved_v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_first[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=eps,
            )
    elif dim == saved_v.ndim - 1:
        M = math.prod(saved_v.shape[:dim])
        N = saved_v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_last[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=eps,
            )
    else:
        # Previously this fell through and returned the uninitialized
        # ``empty_like`` tensors for unsupported dims — fail loudly instead.
        raise ValueError(
            f"weight_norm_interface_backward supports dim == 0 or the last dim, got {dim}"
        )
    return v_grad, g_grad