Coverage for src/flag_gems/runtime/backend/_cambricon/fused/weight_norm.py: 0%
124 statements
« prev ^ index » next — generated by coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
12from ..ops import weight_norm_interface, weight_norm_interface_backward
14logger = logging.getLogger(__name__)
16MAX_N = 31744
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel"),
    key=["v_shape0", "v_shape1", "v_shape2"],
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_kernel(
    output,
    norm,
    v,
    g,
    v_shape0,
    v_shape1,
    v_shape2,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Weight-norm forward for v viewed as (v_shape0, v_shape1, v_shape2),
    # normalizing over every axis except the middle one:
    #   norm[n]         = sqrt(eps + sum_{m,k} v[m, n, k]^2)
    #   output[m, n, k] = v[m, n, k] * g[n] / norm[n]
    # Assumes v/output are contiguous in the (m, n, k) layout and that
    # norm/g are indexable by the plain row index n.
    # Each program instance owns BLOCK_ROW_SIZE consecutive n-rows.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < v_shape1

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: accumulate the sum of squares across the flattened m*k axis
    # in BLOCK_COL_SIZE-wide chunks, in float32.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Decompose the flattened column index into (m, k) coordinates.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        # Row-major offset into the contiguous (m, n, k) layout.
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        v_block += v_value * v_value
    # eps added inside the sqrt keeps the norm strictly positive.
    v_sum = tl.sum(v_block, axis=1) + eps
    v_norm = tl.sqrt(v_sum[:, None])
    tl.store(norm + row_offset, v_norm, mask=row_mask)

    # Pass 2: re-read v and emit the g-scaled, normalized output.
    g_value = tl.load(g + row_offset, mask=row_mask)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        out = v_value * g_value / v_norm
        tl.store(output + v_offsets, out, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel"),
    key=["v_shape0", "v_shape1", "v_shape2"],
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_bwd_kernel(
    v_grad,
    g_grad,
    grad,
    v,
    g,
    norm,
    v_shape0,
    v_shape1,
    v_shape2,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Weight-norm backward for v viewed as (v_shape0, v_shape1, v_shape2),
    # with the norm taken over every axis except the middle one:
    #   vw_sum[n]       = sum_{m,k} v[m, n, k] * grad[m, n, k]
    #   v_grad[m, n, k] = g[n] * (grad[m, n, k] / norm[n]
    #                             - v[m, n, k] * vw_sum[n] / norm[n]^3)
    #   g_grad[n]       = vw_sum[n] / norm[n]
    # eps in the denominators guards against division by zero.
    # Assumes grad/v/v_grad share the contiguous (m, n, k) layout.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < v_shape1

    # Per-row scale g[n] and forward norm[n], promoted to float32.
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # Pass 1: accumulate the inner product sum_{m,k} v * grad per row.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Decompose the flattened column index into (m, k) coordinates.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        # Row-major offset into the contiguous (m, n, k) layout.
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_block += v_value * grad_value
    vw_sum = tl.sum(v_block, axis=1)[:, None]

    # Pass 2: re-read v and grad to produce the elementwise v gradient.
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            grad_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + v_offsets, v_grad_value, mask=mask)

    # Per-row gradient w.r.t. the scale g.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)
def weight_norm_except_dim(v, g, dim):
    """Forward weight norm normalizing over all dims of ``v`` except ``dim``.

    Returns ``(output, norm)`` where ``norm`` is a float32 tensor shaped
    like ``g`` holding the per-slice Euclidean norms.
    """
    logger.debug("GEMS_CAMBRICON WEIGHT NORM EXCEPT DIM FORWARD")
    v = v.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g, dtype=torch.float32)

    # Collapse v into (outer, rows, inner) around the kept dimension.
    outer = math.prod(v.shape[:dim])
    rows = v.shape[dim]
    inner = math.prod(v.shape[dim + 1 :])

    def grid(meta):
        # One program per BLOCK_ROW_SIZE slices along the kept dim.
        return (triton.cdiv(rows, meta["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_kernel[grid](
            output,
            norm,
            v,
            g,
            outer,
            rows,
            inner,
            # tiny > 0 keeps the norm away from exact zero.
            eps=torch.finfo(torch.float32).tiny,
        )
    return output, norm
def weight_norm_except_dim_backward(grad, v, g, norm, dim):
    """Backward pass for :func:`weight_norm_except_dim`.

    Computes gradients w.r.t. ``v`` and ``g`` given the upstream ``grad``
    and the forward-pass ``norm``. Returns ``(v_grad, g_grad)``.
    """
    logger.debug("GEMS_CAMBRICON NORM BACKWARD")
    grad = grad.contiguous()
    # Fix: the kernel addresses v with raw contiguous (outer, rows, inner)
    # offsets, but the v saved by autograd may be non-contiguous (the
    # forward wrapper only made its local copy contiguous). Without this,
    # backward would read v at wrong addresses for non-contiguous inputs.
    v = v.contiguous()
    v_grad = torch.empty_like(v)
    g_grad = torch.empty_like(g)
    # Collapse v into (outer, rows, inner) around the kept dimension.
    v_shape = [
        math.prod(v.shape[:dim]),
        v.shape[dim],
        math.prod(v.shape[dim + 1 :]),
    ]

    # One program per BLOCK_ROW_SIZE slices along the kept dim.
    grid = lambda META: (triton.cdiv(v_shape[1], META["BLOCK_ROW_SIZE"]),)
    with torch_device_fn.device(v.device):
        weight_norm_except_dim_bwd_kernel[grid](
            v_grad,
            g_grad,
            grad,
            v,
            g,
            norm,
            *v_shape,
            # tiny > 0 guards the divisions by norm in the kernel.
            eps=torch.finfo(torch.float32).tiny,
        )
    return v_grad, g_grad
class WeightNorm(torch.autograd.Function):
    """Autograd Function for weight normalization.

    Dispatches to the fused interface when normalizing around the first or
    last dimension, and to the generic except-dim path otherwise.
    """

    @staticmethod
    def forward(ctx, v, g, dim=0):
        logger.debug("GEMS_CAMBRICON WEIGHT NORM")
        # Normalize negative dims to a canonical non-negative index.
        dim = dim % v.ndim
        fused = dim in (0, v.ndim - 1)
        forward_fn = weight_norm_interface if fused else weight_norm_except_dim
        output, norm = forward_fn(v, g, dim)
        ctx.save_for_backward(v, g, norm)
        ctx.dim = dim
        ctx.can_use_fused = fused
        return output

    @staticmethod
    def backward(ctx, grad):
        logger.debug("GEMS_CAMBRICON WEIGHT NORM BACKWARD")
        v, g, norm = ctx.saved_tensors
        backward_fn = (
            weight_norm_interface_backward
            if ctx.can_use_fused
            else weight_norm_except_dim_backward
        )
        v_grad, g_grad = backward_fn(grad, v, g, norm, ctx.dim)
        # No gradient for the non-tensor `dim` argument.
        return v_grad, g_grad, None
def weight_norm(v, g, dim=0):
    """Public entry point: apply weight normalization of ``v`` scaled by
    ``g`` around dimension ``dim`` via the custom autograd Function."""
    return WeightNorm.apply(v, g, dim)