Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/weightnorm.py: 0%
171 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def weight_norm_kernel_last_block_row(args):
    """Heuristic row-block size for ``weight_norm_kernel_last``.

    Fixed at 1 for this backend. An alternative sizing
    (``min(args["M"], 8192)``) sat after the ``return`` as unreachable
    dead code and has been removed.
    """
    return 1
def weight_norm_kernel_last_block_col(args):
    """Heuristic column-block size: next power of two of ceil(N / 12)."""
    n_cols = args["N"]
    cols_per_block = triton.cdiv(n_cols, 12)
    return triton.next_power_of_2(cols_per_block)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": weight_norm_kernel_last_block_row,
        "BLOCK_COL_SIZE": weight_norm_kernel_last_block_col,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_last(
    output,
    norm,
    v,
    g,
    M: tl.constexpr,
    N: tl.constexpr,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Forward weight norm for the dim == last case. `v` is addressed as a
    # row-major (M, N) view and the L2 norm is reduced over the M rows of
    # each column:
    #   norm[j]      = sqrt(sum_i v[i, j]^2 + eps)
    #   output[i, j] = v[i, j] / norm[j] * g[j]
    # One program handles BLOCK_COL_SIZE columns.
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tle.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = bx + tx
    col_mask = col_offset < N

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
    # First pass: accumulate per-column sums of squares across row tiles.
    v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    # eps keeps the sqrt (and the later division) away from exact zero.
    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + col_offset, normalized[:, None], mask=col_mask)
    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)

    # Second pass: re-read v, scale each element by g / norm, write output.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)
def weight_norm_kernel_first_block_row(args):
    """Heuristic row-block size: next power of two of ceil(M / 12)."""
    n_rows = args["M"]
    rows_per_block = triton.cdiv(n_rows, 12)
    return triton.next_power_of_2(rows_per_block)
def weight_norm_kernel_first_block_col(args):
    """Heuristic column-block size for ``weight_norm_kernel_first``.

    Constant on this backend, regardless of the problem shape.
    """
    del args  # shape is not consulted
    return 1
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": weight_norm_kernel_first_block_row,
        "BLOCK_COL_SIZE": weight_norm_kernel_first_block_col,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_first(
    output,
    norm,
    v,
    g,
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Forward weight norm for the dim == 0 case. `v` is addressed as a
    # row-major (M, N) view and the L2 norm is reduced over the N trailing
    # elements of each row:
    #   norm[i]      = sqrt(sum_j v[i, j]^2 + eps)
    #   output[i, j] = v[i, j] / norm[i] * g[i]
    # One program handles BLOCK_ROW_SIZE rows.
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # First pass: accumulate per-row sums of squares across column tiles.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value

    # eps keeps the sqrt (and the later division) away from exact zero.
    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + row_offset, normalized[:, None], mask=row_mask)
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)

    # Second pass: re-read v, scale each element by g / norm, write output.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)
def heur_block_m_weight_norm_bwd_kernel_last(args):
    """Heuristic row-block size for ``weight_norm_bwd_kernel_last``.

    Constant on this backend, regardless of the problem shape.
    """
    del args  # shape is not consulted
    return 1
def heur_block_n_weight_norm_bwd_kernel_last(args):
    """Heuristic column-block size: next power of two of ceil(N / 12)."""
    n_cols = args["N"]
    cols_per_block = triton.cdiv(n_cols, 12)
    return triton.next_power_of_2(cols_per_block)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": heur_block_m_weight_norm_bwd_kernel_last,
        "BLOCK_COL_SIZE": heur_block_n_weight_norm_bwd_kernel_last,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_last(
    v_grad,
    g_grad,
    w,
    v,
    g,
    norm,
    M: tl.constexpr,
    N: tl.constexpr,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Backward weight norm for the dim == last case. `w` holds the upstream
    # gradient of the forward output (see weight_norm_interface_backward,
    # which passes w_grad here). Per column j of the (M, N) view:
    #   vw_sum       = sum_i v[i, j] * w[i, j]
    #   v_grad[i, j] = g[j] * (w[i, j] / (norm[j] + eps)
    #                          - v[i, j] * vw_sum / (norm[j]^3 + eps))
    #   g_grad[j]    = vw_sum / (norm[j] + eps)
    # One program handles BLOCK_COL_SIZE columns.
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tle.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = tx + bx
    col_mask = col_offset < N

    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)
    norm_value = tl.load(norm + col_offset, mask=col_mask).to(tl.float32)

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]

    # First pass: accumulate the per-column dot product of v and w.
    vw_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        vw_block += v_value * w_value
    vw_sum = tl.sum(vw_block, 1)[:, None]

    # Second pass: re-read v and w and emit the per-element v gradient.
    # eps guards both divisions against a zero norm.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + col_offset, g_grad_value, mask=col_mask)
def heur_block_m_weight_norm_bwd_kernel_first(args):
    """Heuristic row-block size: next power of two of ceil(M / 12)."""
    n_rows = args["M"]
    rows_per_block = triton.cdiv(n_rows, 12)
    return triton.next_power_of_2(rows_per_block)
def heur_block_n_weight_norm_bwd_kernel_first(args):
    """Heuristic column-block size for ``weight_norm_bwd_kernel_first``.

    Constant on this backend, regardless of the problem shape.
    """
    del args  # shape is not consulted
    return 1
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": heur_block_m_weight_norm_bwd_kernel_first,
        "BLOCK_COL_SIZE": heur_block_n_weight_norm_bwd_kernel_first,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_first(
    v_grad,
    g_grad,
    w,
    v,
    g,
    norm,
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Backward weight norm for the dim == 0 case. `w` holds the upstream
    # gradient of the forward output (see weight_norm_interface_backward,
    # which passes w_grad here). Per row i of the (M, N) view:
    #   vw_sum       = sum_j v[i, j] * w[i, j]
    #   v_grad[i, j] = g[i] * (w[i, j] / (norm[i] + eps)
    #                          - v[i, j] * vw_sum / (norm[i]^3 + eps))
    #   g_grad[i]    = vw_sum / (norm[i] + eps)
    # One program handles BLOCK_ROW_SIZE rows.
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # First pass: accumulate the per-row dot product of v and w.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * w_value
    vw_sum = tl.sum(v_block, 1)[:, None]

    # Second pass: re-read v and w and emit the per-element v gradient.
    # eps guards both divisions against a zero norm.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            w_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)
def weight_norm_interface(v, g, dim=0):
    """Forward of weight normalization: output = v / ||v|| * g.

    Only the two layouts the kernels above handle are supported:
    ``dim == 0`` (norm reduced over all trailing dims per leading slice)
    and ``dim == v.ndim - 1`` (norm reduced over all leading dims per
    trailing slice).

    Args:
        v: direction tensor.
        g: magnitude tensor, one value per slice along ``dim``.
        dim: the dimension that is kept; all other dims are reduced.

    Returns:
        (output, norm): the normalized tensor (same shape as ``v``) and
        the per-slice norms (same shape as ``g``).

    Raises:
        ValueError: if ``dim`` is neither 0 nor the last dimension.
    """
    logger.debug("GEMS WEIGHT NORM INTERFACE FORWARD")
    v = v.contiguous()
    g = g.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g)
    if dim == 0:
        # View v as (M, N) = (kept dim, everything else); reduce over N per row.
        M = v.shape[0]
        N = math.prod(v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_first[grid](
                output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
            )
    elif dim == v.ndim - 1:
        # View v as (M, N) = (everything else, kept dim); reduce over M per column.
        M = math.prod(v.shape[:-1])
        N = v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_last[grid](
                output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny
            )
    else:
        # Previously an unsupported dim fell through and returned the
        # uninitialized empty_like tensors; fail loudly instead.
        raise ValueError(
            f"weight_norm_interface only supports dim=0 or dim={v.ndim - 1}, "
            f"got dim={dim}"
        )
    return output, norm
def weight_norm_interface_backward(w_grad, saved_v, saved_g, saved_norms, dim):
    """Backward of weight normalization.

    Args:
        w_grad: upstream gradient of the forward output.
        saved_v: direction tensor from the forward pass.
        saved_g: magnitude tensor from the forward pass.
        saved_norms: per-slice norms computed by the forward pass.
        dim: the dimension that was kept in the forward pass; must be 0 or
            ``saved_v.ndim - 1``, matching the two kernels above.

    Returns:
        (v_grad, g_grad): gradients w.r.t. ``saved_v`` and ``saved_g``.

    Raises:
        ValueError: if ``dim`` is neither 0 nor the last dimension.
    """
    logger.debug("GEMS WEIGHT NORM INTERFACE BACKWARD")
    w_grad = w_grad.contiguous()
    saved_v = saved_v.contiguous()
    saved_g = saved_g.contiguous()
    saved_norms = saved_norms.contiguous()
    v_grad = torch.empty_like(saved_v)
    g_grad = torch.empty_like(saved_g)

    if dim == 0:
        # View saved_v as (M, N) = (kept dim, everything else).
        M = saved_v.shape[0]
        N = math.prod(saved_v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_first[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=torch.finfo(torch.float32).tiny,
            )
    elif dim == saved_v.ndim - 1:
        # View saved_v as (M, N) = (everything else, kept dim).
        M = math.prod(saved_v.shape[:dim])
        N = saved_v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_last[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=torch.finfo(torch.float32).tiny,
            )
    else:
        # Previously an unsupported dim fell through and returned the
        # uninitialized empty_like tensors; fail loudly instead.
        raise ValueError(
            f"weight_norm_interface_backward only supports dim=0 or "
            f"dim={saved_v.ndim - 1}, got dim={dim}"
        )
    return v_grad, g_grad