Coverage for src/flag_gems/ops/weightnorm.py: 29%
160 statements
« prev ^ index » next    coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, tl_extra_shim
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
# Forward weight-norm kernel for the case where the kept dimension is the LAST
# one: v is viewed as a row-major (M, N) matrix and the L2 norm is reduced over
# the M rows, producing one norm per column. Each program instance owns a tile
# of BLOCK_COL_SIZE columns and loops over the rows in BLOCK_ROW_SIZE chunks.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_last(
    output,  # pointer: (M, N) normalized result, g * v / norm
    norm,  # pointer: length-N per-column L2 norms
    v,  # pointer: (M, N) contiguous weight tensor
    g,  # pointer: length-N gain
    M,  # reduction extent (rows)
    N,  # kept extent (columns)
    eps,  # small constant added under the sqrt to keep it well-defined
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Column tile handled by this program: shape (BLOCK_COL_SIZE, 1).
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tle.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = bx + tx
    col_mask = col_offset < N

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
    # Pass 1: accumulate per-column sums of squares over all M rows.
    v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value
    # Per-column L2 norm; eps guards the sqrt for all-zero columns.
    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + col_offset, normalized[:, None], mask=col_mask)
    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)
    # Pass 2: re-read v and write g * v / norm for every element of the tile.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)
# Forward weight-norm kernel for the case where the kept dimension is the
# FIRST one: v is viewed as a row-major (M, N) matrix and the L2 norm is
# reduced over the N columns, producing one norm per row. Each program
# instance owns a tile of BLOCK_ROW_SIZE rows and loops over the columns in
# BLOCK_COL_SIZE chunks.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_kernel_first(
    output,  # pointer: (M, N) normalized result, g * v / norm
    norm,  # pointer: length-M per-row L2 norms
    v,  # pointer: (M, N) contiguous weight tensor
    g,  # pointer: length-M gain
    M,  # kept extent (rows)
    N,  # reduction extent (columns)
    eps,  # small constant added under the sqrt to keep it well-defined
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Row tile handled by this program: shape (BLOCK_ROW_SIZE, 1).
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: accumulate per-row sums of squares over all N columns.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * v_value
    # Per-row L2 norm; eps guards the sqrt for all-zero rows.
    normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)
    tl.store(norm + row_offset, normalized[:, None], mask=row_mask)
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    # Pass 2: re-read v and write g * v / norm for every element of the tile.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_vec = v_value / normalized[:, None]
        out = v_vec * g_value
        tl.store(output + row_offset * N + col_offset, out, mask=mask)
# Backward weight-norm kernel, kept dimension LAST. Given the upstream
# gradient w (d_loss/d_output) and the saved norms, computes per column:
#   vw_sum = sum_rows(v * w)
#   v_grad = g * (w / norm - v * vw_sum / norm^3)
#   g_grad = vw_sum / norm
# Each program instance owns BLOCK_COL_SIZE columns and loops over the M rows.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_last"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_last(
    v_grad,  # pointer: (M, N) gradient w.r.t. v (written)
    g_grad,  # pointer: length-N gradient w.r.t. g (written)
    w,  # pointer: (M, N) upstream gradient of the forward output
    v,  # pointer: (M, N) saved weight tensor
    g,  # pointer: length-N saved gain
    norm,  # pointer: length-N saved per-column norms from the forward pass
    M,  # reduction extent (rows)
    N,  # kept extent (columns)
    eps,  # small constant added to norm before dividing
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Column tile handled by this program: shape (BLOCK_COL_SIZE, 1).
    tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]
    bx = tle.program_id(axis=0) * BLOCK_COL_SIZE
    col_offset = tx + bx
    col_mask = col_offset < N

    g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)
    norm_value = tl.load(norm + col_offset, mask=col_mask).to(tl.float32)
    # Precompute 1/norm and 1/norm^3; eps guards against a zero norm.
    norm_1 = 1 / (norm_value + eps)
    norm_3 = tl_extra_shim.pow(norm_1, 3)

    ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]

    # Pass 1: per-column dot product of v with the upstream gradient.
    vw_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        vw_block += v_value * w_value
    vw_sum = tl.sum(vw_block, 1)[:, None]

    # Pass 2: elementwise gradient w.r.t. v.
    for base in range(0, M, BLOCK_ROW_SIZE):
        row_offset = base + ty
        mask = row_offset < M and col_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (w_value * norm_1 - v_value * norm_3 * vw_sum)
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    # Gradient w.r.t. g is the reduced dot product scaled by 1/norm.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + col_offset, g_grad_value, mask=col_mask)
# Backward weight-norm kernel, kept dimension FIRST. Given the upstream
# gradient w (d_loss/d_output) and the saved norms, computes per row:
#   vw_sum = sum_cols(v * w)
#   v_grad = g * (w / norm - v * vw_sum / norm^3)
#   g_grad = vw_sum / norm
# Each program instance owns BLOCK_ROW_SIZE rows and loops over the N columns.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel_first"), key=["M", "N"]
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_bwd_kernel_first(
    v_grad,  # pointer: (M, N) gradient w.r.t. v (written)
    g_grad,  # pointer: length-M gradient w.r.t. g (written)
    w,  # pointer: (M, N) upstream gradient of the forward output
    v,  # pointer: (M, N) saved weight tensor
    g,  # pointer: length-M saved gain
    norm,  # pointer: length-M saved per-row norms from the forward pass
    M,  # kept extent (rows)
    N,  # reduction extent (columns)
    eps,  # small constant added to norm before dividing
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Row tile handled by this program: shape (BLOCK_ROW_SIZE, 1).
    ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    by = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = by + ty
    row_mask = row_offset < M

    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)
    # Precompute 1/norm and 1/norm^3; eps guards against a zero norm.
    norm_1 = 1 / (norm_value + eps)
    norm_3 = tl_extra_shim.pow(norm_1, 3)

    tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # Pass 1: per-row dot product of v with the upstream gradient.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_block += v_value * w_value
    vw_sum = tl.sum(v_block, 1)[:, None]

    # Pass 2: elementwise gradient w.r.t. v.
    for base in range(0, N, BLOCK_COL_SIZE):
        col_offset = base + tx
        mask = col_offset < N and row_mask
        v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)
        w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)
        v_grad_value = g_value * (w_value * norm_1 - v_value * norm_3 * vw_sum)
        tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)

    # Gradient w.r.t. g is the reduced dot product scaled by 1/norm.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)
def weight_norm_interface(v, g, dim=0):
    """Fused weight-norm forward: output = g * v / ||v||, norms taken over
    every dimension except ``dim``.

    ``v`` is flattened to a 2D (M, N) view with the kept dimension either
    first (``dim == 0``) or last (``dim == v.ndim - 1``); those are the only
    layouts the kernels support.

    Args:
        v: weight tensor (made contiguous before launch).
        g: gain tensor with one value per slice along ``dim`` (made
            contiguous before launch).
        dim: the dimension kept un-reduced; must be 0 or ``v.ndim - 1``.

    Returns:
        (output, norm): the normalized weight (same shape as ``v``) and the
        per-slice L2 norms (same shape as ``g``).

    Raises:
        ValueError: if ``dim`` is neither 0 nor ``v.ndim - 1``. (Previously
        this case silently returned uninitialized tensors.)
    """
    logger.debug("GEMS WEIGHT NORM INTERFACE FORWARD")
    v = v.contiguous()
    g = g.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g)
    # Smallest positive normal float32: avoids division by zero in the
    # kernels without noticeably perturbing regular-sized norms.
    eps = torch.finfo(torch.float32).tiny
    if dim == 0:
        M = v.shape[0]
        N = math.prod(v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_first[grid](output, norm, v, g, M, N, eps=eps)
    elif dim == v.ndim - 1:
        M = math.prod(v.shape[:-1])
        N = v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(v.device):
            weight_norm_kernel_last[grid](output, norm, v, g, M, N, eps=eps)
    else:
        raise ValueError(
            f"weight_norm_interface only supports dim=0 or dim={v.ndim - 1}, "
            f"got dim={dim}"
        )
    return output, norm
def weight_norm_interface_backward(w_grad, saved_v, saved_g, saved_norms, dim):
    """Fused weight-norm backward.

    Uses the tensors saved by the forward pass to compute, per slice along
    ``dim`` (with vw_sum the dot product of ``saved_v`` and ``w_grad`` over
    the reduced dimensions):

        v_grad = g * (w_grad / norm - v * vw_sum / norm^3)
        g_grad = vw_sum / norm

    Args:
        w_grad: upstream gradient of the forward output (made contiguous).
        saved_v: weight tensor saved by the forward pass (made contiguous).
        saved_g: gain tensor saved by the forward pass (made contiguous).
        saved_norms: per-slice norms saved by the forward pass (made
            contiguous).
        dim: the dimension kept un-reduced; must be 0 or ``saved_v.ndim - 1``.

    Returns:
        (v_grad, g_grad): gradients w.r.t. the weight and the gain.

    Raises:
        ValueError: if ``dim`` is neither 0 nor ``saved_v.ndim - 1``.
        (Previously this case silently returned uninitialized tensors.)
    """
    logger.debug("GEMS WEIGHT NORM INTERFACE BACKWARD")
    w_grad = w_grad.contiguous()
    saved_v = saved_v.contiguous()
    saved_g = saved_g.contiguous()
    saved_norms = saved_norms.contiguous()
    v_grad = torch.empty_like(saved_v)
    g_grad = torch.empty_like(saved_g)
    # Same guard value as the forward pass: avoids division by a zero norm.
    eps = torch.finfo(torch.float32).tiny

    if dim == 0:
        M = saved_v.shape[0]
        N = math.prod(saved_v.shape[1:])
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_first[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=eps,
            )
    elif dim == saved_v.ndim - 1:
        M = math.prod(saved_v.shape[:dim])
        N = saved_v.shape[dim]
        grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),)
        with torch_device_fn.device(saved_v.device):
            weight_norm_bwd_kernel_last[grid](
                v_grad,
                g_grad,
                w_grad,
                saved_v,
                saved_g,
                saved_norms,
                M,
                N,
                eps=eps,
            )
    else:
        raise ValueError(
            f"weight_norm_interface_backward only supports dim=0 or "
            f"dim={saved_v.ndim - 1}, got dim={dim}"
        )
    return v_grad, g_grad