Coverage for src/flag_gems/ops/vector_norm.py: 24%
259 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry, tl_extra_shim
11from flag_gems.utils import triton_lang_extension as tle
# Vendor-specific elementwise power op resolved through the device shim; used
# by the generic p-norm kernels below.
# NOTE: intentionally shadows the builtin `pow` within this module.
pow = tl_extra_shim.pow
logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise L2 norm: X is addressed as an (M, N) matrix with row stride N
    # (the caller compresses reduced dims to the tail via dim_compress), and
    # one sqrt(sum(x^2)) value is written to Out per row.
    # int64 row ids so the `X + pid * N` offsets cannot overflow int32.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Accumulate squares in float32 for precision, tiling columns by BLOCK_N.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): relies on Triton lowering `and` on tensors to an
        # elementwise logical-and (same pattern in all sibling kernels).
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += a * a
    sum = tl.sum(_sum, axis=1)

    out = tl.sqrt(sum)[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    # Stage 1 of the full-tensor L2 norm: each program sums the squares of a
    # BLOCK_SIZE chunk of the flattened input (M elements total) into Mid[pid].
    # The sqrt is deferred to l2_norm_kernel_2.
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0 keeps out-of-range lanes neutral for the sum.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(x * x)
    tl.store(Mid, mid)
@libentry()
@triton.jit
def l2_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full-tensor L2 norm: a single program reduces the
    # MID_SIZE partial sums-of-squares and writes sqrt of the total to Out.
    # Assumes BLOCK_MID >= MID_SIZE (caller uses next_power_of_2(MID_SIZE)).
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sqrt(tl.sum(mid))
    tl.store(Out, out)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise infinity norm: max(|x|) per row of the (M, N)-addressed input.
    # int64 row ids so the `X + pid * N` offsets cannot overflow int32.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Running elementwise maxima, tiled by BLOCK_N columns; masked lanes load
    # 0.0, which is neutral for max(|x|).
    _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): relies on Triton lowering `and` elementwise.
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _max = tl.maximum(tl.abs(a), _max)

    max = tl.max(_max, axis=1)
    out = max[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def max_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    # Stage 1 of the full-tensor infinity norm: each program writes the
    # max(|x|) of its BLOCK_SIZE chunk to Mid[pid].
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0 is neutral for max(|x|).
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.max(tl.abs(x))
    tl.store(Mid, mid)
@libentry()
@triton.jit
def max_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full-tensor infinity norm: one program takes the max of
    # the MID_SIZE partial maxima. Assumes BLOCK_MID >= MID_SIZE.
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.max(mid)
    tl.store(Out, out)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise negative-infinity norm: min(|x|) per row of the
    # (M, N)-addressed input.
    # int64 row ids so the `X + pid * N` offsets cannot overflow int32.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Running minima start at +inf; masked lanes also load +inf, which is
    # neutral for min(|x|).
    _min = tl.full([BLOCK_M, BLOCK_N], value=float("inf"), dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): relies on Triton lowering `and` elementwise.
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=float("inf")).to(tl.float32)
        _min = tl.minimum(tl.abs(a), _min)

    min = tl.min(_min, axis=1)
    out = min[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def min_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    # Stage 1 of the full-tensor -inf norm: each program writes the min(|x|)
    # of its BLOCK_SIZE chunk to Mid[pid].
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=+inf is neutral for the min reduction.
    x = tl.load(X, mask=mask, other=float("inf")).to(tl.float32)
    mid = tl.min(tl.abs(x))
    tl.store(Mid, mid)
@libentry()
@triton.jit
def min_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full-tensor -inf norm: one program takes the min of the
    # MID_SIZE partial minima. Assumes BLOCK_MID >= MID_SIZE.
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=float("inf")).to(tl.float32)
    out = tl.min(mid)
    tl.store(Out, out)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise L0 "norm": the count of non-zero elements per row of the
    # (M, N)-addressed input, written as a float to Out[row].
    # FIX: cast the program id to int64 before scaling (as every sibling
    # tiled kernel here does) so `X + pid * N` cannot overflow int32 offsets
    # on tensors with more than 2**31 elements.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Accumulate 0/1 indicators in float32, tiled by BLOCK_N columns; masked
    # lanes load 0 and therefore do not contribute to the count.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): relies on Triton lowering `and` elementwise.
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0).to(tl.float32)
        _sum += tl.where(a != 0, 1, 0)
    sum = tl.sum(_sum, axis=1)
    out = sum[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def l0_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    # Stage 1 of the full-tensor L0 norm: each program counts the non-zero
    # elements in its BLOCK_SIZE chunk and writes the count to Mid[pid].
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # Masked lanes load 0.0 and thus count as zero.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    cnt = (x != 0).to(tl.float32)
    mid = tl.sum(cnt)
    tl.store(Mid, mid)
@libentry()
@triton.jit
def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full-tensor L0 norm: one program sums the MID_SIZE
    # partial counts. Assumes BLOCK_MID >= MID_SIZE.
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sum(mid)
    tl.store(Out, out)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit(do_not_specialize=["ord"])
def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise generic p-norm: (sum(|x|^ord))^(1/ord) per row of the
    # (M, N)-addressed input. `ord` must be non-zero (ord == 0, 2, +/-inf
    # are routed to the specialized kernels by the caller).
    # int64 row ids so the `X + pid * N` offsets cannot overflow int32.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Accumulate |x|^ord in float32, tiled by BLOCK_N columns.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): relies on Triton lowering `and` elementwise.
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += pow(tl.abs(a), ord)
    sum = tl.sum(_sum, axis=1)
    out = pow(sum, 1 / ord)[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
    # Stage 1 of the full-tensor generic p-norm (despite the "l1" name, it
    # handles any non-zero ord): each program sums |x|^ord over its
    # BLOCK_SIZE chunk into Mid[pid]; the 1/ord root is applied in stage 2.
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # other=0.0 is neutral for the sum of |x|^ord.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(pow(tl.abs(x), ord))
    tl.store(Mid, mid)
@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_2(Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full-tensor generic p-norm: one program sums the
    # MID_SIZE partials and applies the 1/ord root. `ord` must be non-zero.
    # Assumes BLOCK_MID >= MID_SIZE.
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = pow(tl.sum(mid), 1 / ord)
    tl.store(Out, out)
def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    """Compute the vector norm of ``x`` over ``dim`` with Triton kernels.

    Mirrors ``torch.linalg.vector_norm`` semantics for the supported cases:
    ``ord`` values 2, ``inf``, ``-inf`` and 0 use specialized kernels; any
    other ``ord`` uses the generic p-norm kernels.

    Args:
        x: input tensor; must be float16, float32 or bfloat16.
        ord: norm order (default 2).
        dim: int, sequence of ints, or None. None, an empty sequence, or all
            dims means a full reduction over the flattened tensor.
        keepdim: keep the reduced dimensions as size-1 dims.
        dtype: optional output dtype, given as a torch.dtype or its string
            name (e.g. "float32").

    Returns:
        Tensor holding the requested norm.

    Raises:
        NotImplementedError: if the resolved dtype is unsupported.
    """
    logger.debug("GEMS VECTOR NORM")
    if dtype is not None:
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        elif not isinstance(dtype, torch.dtype):
            # Unknown dtype specifier: best-effort fallback (preserves the
            # module's historical behavior rather than raising).
            dtype = torch.float32
    else:
        dtype = x.dtype
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")

    # FIX: accept a bare int dim like torch.linalg.vector_norm does.
    # Previously an int crashed at len(dim), and dim=0 was silently treated
    # as a full reduction because `not 0` is truthy.
    if isinstance(dim, int):
        dim = [dim]

    with torch_device_fn.device(x.device):
        if (not dim) or len(dim) == x.ndim:
            # Full reduction: flatten and run the two-stage kernels.
            dim = list(range(x.ndim))
            shape = [1] * x.ndim
            x = dim_compress(x, dim)
            M = x.numel()
            # ~sqrt(M)-sized blocks balance stage-1 parallelism against the
            # stage-2 single-program reduction.
            BLOCK_SIZE = triton.next_power_of_2(math.ceil(math.sqrt(M)))
            MID_SIZE = triton.cdiv(M, BLOCK_SIZE)
            BLOCK_MID = triton.next_power_of_2(MID_SIZE)

            mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)
            out = torch.empty(shape, dtype=dtype, device=x.device)
            if ord == 2:
                l2_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                l2_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == float("inf"):
                max_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                max_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == -float("inf"):
                min_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                min_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == 0:
                l0_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                l0_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            else:
                # Generic p-norm path (the "l1" kernels take ord at runtime).
                l1_norm_kernel_1[(MID_SIZE,)](x, mid, ord, M, BLOCK_SIZE)
                l1_norm_kernel_2[(1,)](mid, out, ord, MID_SIZE, BLOCK_MID)
        else:
            # Partial reduction: move reduced dims to the tail so each output
            # row is a contiguous run of N input elements.
            shape = list(x.shape)
            dim = [d % x.ndim for d in dim]  # normalize negative dims
            x = dim_compress(x, dim)
            N = 1
            for i in dim:
                N *= shape[i]
                shape[i] = 1
            M = x.numel() // N
            out = torch.empty(shape, dtype=dtype, device=x.device)
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            if ord == 2:
                l2_norm_kernel[grid](x, out, M, N)
            elif ord == float("inf"):
                max_norm_kernel[grid](x, out, M, N)
            elif ord == -float("inf"):
                min_norm_kernel[grid](x, out, M, N)
            elif ord == 0:
                l0_norm_kernel[grid](x, out, M, N)
            else:
                v_norm_kernel[grid](x, out, M, N, ord)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out