Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/weight_norm.py: 0%
127 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12from ..ops import weight_norm_interface, weight_norm_interface_backward
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_row_weight_norm_except_dim_kernel(args):
    """Heuristic for BLOCK_ROW_SIZE: split the normalized axis (``v_shape1``)
    across roughly 12 programs, rounded up to a power of two as Triton
    block sizes require."""
    rows_per_program = triton.cdiv(args["v_shape1"], 12)
    return triton.next_power_of_2(rows_per_program)
def heur_col_weight_norm_except_dim_kernel(args):
    """Heuristic for BLOCK_COL_SIZE: process the flattened outer*inner
    dimension one element at a time (the kernel loops over it)."""
    del args  # unused; signature is fixed by triton.heuristics
    return 1
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel"),
#     key=["v_shape0", "v_shape1", "v_shape2"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": heur_row_weight_norm_except_dim_kernel,
        "BLOCK_COL_SIZE": heur_col_weight_norm_except_dim_kernel,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_kernel(
    output,
    norm,
    v,
    g,
    v_shape0,
    v_shape1,
    v_shape2,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Forward weight-norm kernel for an arbitrary (non-edge) `dim`.

    `v` is addressed as a contiguous 3-D view (v_shape0, v_shape1, v_shape2)
    where the middle axis (size v_shape1) is the kept `dim` — the offset
    formula `m * v_shape1 * v_shape2 + n * v_shape2 + k` below establishes
    this layout.  For each index n along that axis the kernel computes
    norm[n] = sqrt(sum(v[:, n, :]**2) + eps) and
    output[:, n, :] = v[:, n, :] * g[n] / norm[n].
    Each program handles BLOCK_ROW_SIZE consecutive n-rows.
    """
    # Row block: BLOCK_ROW_SIZE indices along the kept axis, as a column vector
    # so they broadcast against the column offsets below.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < v_shape1
    # Column block iterates over the flattened outer*inner extent.
    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: accumulate sum of squares per row in float32.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Decompose the flat column index into outer (m) and inner (k) parts.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2
        mask = m_idx < v_shape0 and row_mask
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        v_block += v_value * v_value
    # eps keeps the sqrt (and the later division) away from zero.
    v_sum = tl.sum(v_block, axis=1) + eps
    v_norm = tl.sqrt(v_sum[:, None])
    tl.store(norm + row_offset, v_norm, mask=row_mask)
    # Pass 2: re-read v and scale by g / norm.
    g_value = tl.load(g + row_offset, mask=row_mask)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2
        mask = m_idx < v_shape0 and row_mask
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        out = v_value * g_value / v_norm
        tl.store(output + v_offsets, out, mask=mask)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel"),
#     key=["v_shape0", "v_shape1", "v_shape2"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": heur_row_weight_norm_except_dim_kernel,
        "BLOCK_COL_SIZE": heur_col_weight_norm_except_dim_kernel,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_bwd_kernel(
    v_grad,
    g_grad,
    grad,
    v,
    g,
    norm,
    v_shape0,
    v_shape1,
    v_shape2,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Backward weight-norm kernel for an arbitrary (non-edge) `dim`.

    Same 3-D addressing as the forward kernel: `v`/`grad` are contiguous
    (v_shape0, v_shape1, v_shape2) views with the kept axis in the middle.
    For each row n (with s = sum(v[:, n, :] * grad[:, n, :])):
        v_grad[:, n, :] = g[n] * (grad/norm - v * s / norm**3)
        g_grad[n]       = s / norm[n]
    eps is added to each denominator to avoid division by zero.
    All arithmetic is carried out in float32.
    """
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < v_shape1
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)
    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: accumulate the per-row dot product sum(v * grad).
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2
        mask = m_idx < v_shape0 and row_mask
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_block += v_value * grad_value
    vw_sum = tl.sum(v_block, axis=1)[:, None]
    # Pass 2: re-read v/grad and emit the per-element v gradient.
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2
        mask = m_idx < v_shape0 and row_mask
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            grad_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + v_offsets, v_grad_value, mask=mask)
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)
def weight_norm_except_dim(v, g, dim):
    """Forward weight norm for an arbitrary kept dimension `dim`.

    Flattens `v` to a 3-D view (outer, v.shape[dim], inner) and launches
    the Triton kernel, one program block per group of rows along `dim`.
    Returns (output, norm) where `norm` is a float32 tensor shaped like `g`.
    """
    logger.debug("GEMS WEIGHT NORM EXCEPT DIM FORWARD")
    v = v.contiguous()
    # Collapse the axes before / after `dim` into single outer / inner extents.
    outer = math.prod(v.shape[:dim])
    rows = v.shape[dim]
    inner = math.prod(v.shape[dim + 1 :])

    output = torch.empty_like(v)
    norm = torch.empty_like(g, dtype=torch.float32)

    def grid(meta):
        return (triton.cdiv(rows, meta["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_kernel[grid](
            output,
            norm,
            v,
            g,
            outer,
            rows,
            inner,
            eps=torch.finfo(torch.float32).tiny,
        )
    return output, norm
def weight_norm_except_dim_backward(grad, v, g, norm, dim):
    """Backward weight norm for an arbitrary kept dimension `dim`.

    Uses the same (outer, v.shape[dim], inner) flattening as the forward
    pass and returns (v_grad, g_grad) with the shapes of `v` and `g`.
    """
    logger.debug("GEMS WEIGHT NORM EXCEPT DIM BACKWARD")
    grad = grad.contiguous()
    # Collapse the axes before / after `dim` into single outer / inner extents.
    outer = math.prod(v.shape[:dim])
    rows = v.shape[dim]
    inner = math.prod(v.shape[dim + 1 :])

    v_grad = torch.empty_like(v)
    g_grad = torch.empty_like(g)

    def grid(meta):
        return (triton.cdiv(rows, meta["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_bwd_kernel[grid](
            v_grad,
            g_grad,
            grad,
            v,
            g,
            norm,
            outer,
            rows,
            inner,
            eps=torch.finfo(torch.float32).tiny,
        )
    return v_grad, g_grad
class WeightNorm(torch.autograd.Function):
    """Autograd wrapper that dispatches between the fused weight-norm
    implementation (when `dim` is the first or last axis) and the
    generic except-dim kernels for interior axes."""

    @staticmethod
    def forward(ctx, v, g, dim=0):
        logger.debug("GEMS WEIGHT NORM")
        dim = dim % v.ndim
        # The fused path only handles reductions over the edge axes.
        fused = dim in (0, v.ndim - 1)
        if fused:
            output, norm = weight_norm_interface(v, g, dim)
        else:
            output, norm = weight_norm_except_dim(v, g, dim)
        ctx.save_for_backward(v, g, norm)
        ctx.dim = dim
        ctx.can_use_fused = fused
        return output

    @staticmethod
    def backward(ctx, grad):
        logger.debug("GEMS WEIGHT NORM BACKWARD")
        v, g, norm = ctx.saved_tensors
        # Mirror the dispatch decision made in forward().
        if ctx.can_use_fused:
            backward_fn = weight_norm_interface_backward
        else:
            backward_fn = weight_norm_except_dim_backward
        v_grad, g_grad = backward_fn(grad, v, g, norm, ctx.dim)
        # No gradient for the `dim` argument.
        return v_grad, g_grad, None
def weight_norm(v, g, dim=0):
    """Public entry point: weight normalization of `v` scaled by `g`,
    keeping dimension `dim` (negative values are wrapped by the Function)."""
    return WeightNorm.apply(v, g, dim)