# Coverage for src/flag_gems/fused/weight_norm.py: 22% (124 statements)
# coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.ops import weight_norm_interface, weight_norm_interface_backward
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils import triton_lang_extension as tle
14logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel"),
    key=["v_shape0", "v_shape1", "v_shape2"],
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_kernel(
    output,
    norm,
    v,
    g,
    v_shape0,
    v_shape1,
    v_shape2,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Forward weight-norm kernel for an arbitrary kept dim. `v` is addressed
    # as a contiguous (v_shape0, v_shape1, v_shape2) view where axis 1 is the
    # normalized-over dim's complement (the kept dim). For each row n along
    # axis 1 it computes:
    #   norm[n]          = sqrt(sum_{m,k} v[m, n, k]^2 + eps)
    #   output[m, n, k]  = v[m, n, k] * g[n] / norm[n]
    # Each program handles BLOCK_ROW_SIZE consecutive axis-1 rows and sweeps
    # the flattened (v_shape0 * v_shape2) "column" space in BLOCK_COL_SIZE
    # chunks.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    # Guard rows past the end of axis 1 (last program may be partial).
    row_mask = row_offset < v_shape1
    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: accumulate per-row sum of squares over all (m, k) positions.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Unflatten the column index into (m, k); n comes from the row.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2
        mask = m_idx < v_shape0 and row_mask
        # Flat offset into the contiguous (shape0, shape1, shape2) layout.
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        v_block += v_value * v_value
    # eps (launched as float32 tiny) keeps sqrt away from exactly zero.
    v_sum = tl.sum(v_block, axis=1) + eps
    v_norm = tl.sqrt(v_sum[:, None])
    tl.store(norm + row_offset, v_norm, mask=row_mask)
    g_value = tl.load(g + row_offset, mask=row_mask)
    # Pass 2: re-read v and write the normalized, g-scaled output.
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2
        mask = m_idx < v_shape0 and row_mask
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        out = v_value * g_value / v_norm
        tl.store(output + v_offsets, out, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel"),
    key=["v_shape0", "v_shape1", "v_shape2"],
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_bwd_kernel(
    v_grad,
    g_grad,
    grad,
    v,
    g,
    norm,
    v_shape0,
    v_shape1,
    v_shape2,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Backward weight-norm kernel, mirroring the forward's addressing: v,
    # grad are contiguous (v_shape0, v_shape1, v_shape2) views, axis 1 is
    # the kept dim. With s[n] = sum_{m,k} v[m, n, k] * grad[m, n, k]:
    #   v_grad[m, n, k] = g[n] * (grad[m, n, k] / norm[n]
    #                             - v[m, n, k] * s[n] / norm[n]^3)
    #   g_grad[n]       = s[n] / norm[n]
    # eps is added to each divisor to guard against division by zero.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    # Guard rows past the end of axis 1 (last program may be partial).
    row_mask = row_offset < v_shape1
    # Accumulate in float32 regardless of the tensors' storage dtype.
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)
    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: per-row dot product of v and incoming grad (s[n] above).
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Unflatten the column index into (m, k); n comes from the row.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2
        mask = m_idx < v_shape0 and row_mask
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_block += v_value * grad_value
    vw_sum = tl.sum(v_block, axis=1)[:, None]
    # Pass 2: re-read v and grad and emit v_grad.
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2
        mask = m_idx < v_shape0 and row_mask
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            grad_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + v_offsets, v_grad_value, mask=mask)
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)
def weight_norm_except_dim(v, g, dim):
    """Forward of weight normalization over every axis except ``dim``.

    ``v`` is logically reshaped to ``(prod(shape[:dim]), shape[dim],
    prod(shape[dim+1:]))`` and the Euclidean norm is taken over axes 0 and 2,
    so each slice along ``dim`` gets one norm / one ``g`` scale.

    Args:
        v: weight tensor of any shape.
        g: per-slice scale tensor; one element per index along ``dim``.
        dim: the axis that is NOT normalized over (already non-negative).

    Returns:
        (output, norm): ``output`` has ``v``'s shape, ``norm`` has ``g``'s
        shape in float32.
    """
    logger.debug("GEMS WEIGHT NORM EXCEPT DIM FORWARD")
    # The kernel computes flat offsets assuming a contiguous layout, and
    # empty_like() inherits strides — so both inputs must be contiguous
    # before allocating the outputs.
    v = v.contiguous()
    g = g.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g, dtype=torch.float32)
    # Collapse v's shape into (before-dim, dim, after-dim).
    v_shape = [
        math.prod(v.shape[:dim]),
        v.shape[dim],
        math.prod(v.shape[dim + 1 :]),
    ]

    # One program per BLOCK_ROW_SIZE indices along the kept dim.
    grid = lambda META: (triton.cdiv(v_shape[1], META["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_kernel[grid](
            output,
            norm,
            v,
            g,
            *v_shape,
            # Smallest positive normal float32: keeps sqrt/div away from 0.
            eps=torch.finfo(torch.float32).tiny,
        )
    return output, norm
def weight_norm_except_dim_backward(grad, v, g, norm, dim):
    """Backward of :func:`weight_norm_except_dim`.

    Args:
        grad: incoming gradient w.r.t. the forward output (``v``'s shape).
        v: the saved weight tensor.
        g: the saved per-slice scale tensor.
        norm: the per-slice norms produced by the forward.
        dim: the axis that was NOT normalized over.

    Returns:
        (v_grad, g_grad): gradients w.r.t. ``v`` and ``g``.
    """
    logger.debug("GEMS WEIGHT NORM EXCEPT DIM BACKWARD")
    # The kernel computes flat offsets assuming contiguous layouts. The
    # forward only made a *local* contiguous copy of v, so the tensors saved
    # by autograd may still be non-contiguous — normalize all of them here,
    # and do so before empty_like() so the outputs get contiguous strides.
    grad = grad.contiguous()
    v = v.contiguous()
    g = g.contiguous()
    norm = norm.contiguous()
    v_grad = torch.empty_like(v)
    g_grad = torch.empty_like(g)
    # Collapse v's shape into (before-dim, dim, after-dim).
    v_shape = [
        math.prod(v.shape[:dim]),
        v.shape[dim],
        math.prod(v.shape[dim + 1 :]),
    ]

    # One program per BLOCK_ROW_SIZE indices along the kept dim.
    grid = lambda META: (triton.cdiv(v_shape[1], META["BLOCK_ROW_SIZE"]),)
    with torch_device_fn.device(v.device):
        weight_norm_except_dim_bwd_kernel[grid](
            v_grad,
            g_grad,
            grad,
            v,
            g,
            norm,
            *v_shape,
            # Smallest positive normal float32: guards the divisions.
            eps=torch.finfo(torch.float32).tiny,
        )
    return v_grad, g_grad
class WeightNorm(torch.autograd.Function):
    """Autograd wrapper for weight normalization.

    Dispatches to the fused interface when ``dim`` is the first or last axis
    and to the generic except-dim path otherwise; the same choice is replayed
    in backward via ``ctx.can_use_fused``.
    """

    @staticmethod
    def forward(ctx, v, g, dim=0):
        logger.debug("GEMS WEIGHT NORM")
        # Normalize negative dims to [0, ndim).
        dim = dim % v.ndim
        # The fused kernels only handle normalization over the first or
        # last axis.
        can_use_fused = dim in (0, v.ndim - 1)
        fwd = weight_norm_interface if can_use_fused else weight_norm_except_dim
        output, norm = fwd(v, g, dim)
        # Stash everything backward needs.
        ctx.save_for_backward(v, g, norm)
        ctx.dim = dim
        ctx.can_use_fused = can_use_fused
        return output

    @staticmethod
    def backward(ctx, grad):
        logger.debug("GEMS WEIGHT NORM BACKWARD")
        v, g, norm = ctx.saved_tensors
        bwd = (
            weight_norm_interface_backward
            if ctx.can_use_fused
            else weight_norm_except_dim_backward
        )
        v_grad, g_grad = bwd(grad, v, g, norm, ctx.dim)
        # No gradient for the integer `dim` argument.
        return v_grad, g_grad, None
def weight_norm(v, g, dim=0):
    """Public entry point: run weight normalization through the autograd
    Function so both forward and backward use the GEMS kernels."""
    result = WeightNorm.apply(v, g, dim)
    return result