Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/vector_norm.py: 0%
263 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import builtins
2import logging
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry, tl_extra_shim
11from flag_gems.utils import triton_lang_extension as tle
# Module logger under the shared "flag_gems" root; lstrip(".") guards
# against a leading dot in __name__ producing an empty child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Backend-provided pow shim; shadows builtins.pow deliberately so the
# Triton kernels below call the device implementation.
pow = tl_extra_shim.pow
def heur_block_m(args):
    """Heuristic BLOCK_M: rows per program, spreading M over 12 clusters,
    rounded up to the next power of two."""
    rows_per_cluster = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_cluster)
def heur_block_n(args):
    """Heuristic BLOCK_N: the reduction-axis tile width, capped at 8192."""
    n = args["N"]
    return n if n < 8192 else 8192
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise L2 norm over a row-major (M, N) input:
    # Out[m] = sqrt(sum(X[m, :] ** 2)). Each program handles BLOCK_M rows.
    # Widen pid to int64 so pid * N offsets cannot overflow int32.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    # Sweep the reduction axis in BLOCK_N-wide tiles, accumulating in fp32;
    # masked lanes contribute 0 to the sum.
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += a * a
    sum = tl.sum(_sum, axis=1)

    out = tl.sqrt(sum)[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def l2_norm_kernel_1(
    X, Mid, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 1 of the full-tensor L2 norm: each program sums the squares of
    # its BLOCK_SIZE slice of the flattened input (length M) and writes one
    # fp32 partial sum to Mid[pid]. Stage 2 combines the partials.
    # NOTE(review): buffer_size_limit appears to be a kunlunxin launch
    # tuning knob; it is not read in the body.
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(x * x)
    tl.store(Mid, mid)
@libentry()
@triton.jit
def l2_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 2 of the full-tensor L2 norm: a single program reduces the
    # MID_SIZE partial sums of squares and stores sqrt(total) into Out.
    # BLOCK_MID is next_power_of_2(MID_SIZE), so one tile covers all partials.
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sqrt(tl.sum(mid))
    tl.store(Out, out)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise inf-norm over a row-major (M, N) input:
    # Out[m] = max(|X[m, :]|). Each program handles BLOCK_M rows.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # 0.0 is the identity for max(|x|) since |x| >= 0, so masked lanes
    # (loaded as 0.0) never win the maximum.
    _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _max = tl.maximum(tl.abs(a), _max)

    max = tl.max(_max, axis=1)
    out = max[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def max_norm_kernel_1(
    X, Mid, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 1 of the full-tensor inf-norm: each program takes max(|x|) over
    # its BLOCK_SIZE slice and writes the partial maximum to Mid[pid].
    # Masked lanes load as 0.0, which cannot exceed any |x|.
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.max(tl.abs(x))
    tl.store(Mid, mid)
@libentry()
@triton.jit
def max_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 2 of the full-tensor inf-norm: one program reduces the MID_SIZE
    # partial maxima into the final max and stores it in Out.
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.max(mid)
    tl.store(Out, out)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise (-inf)-norm over a row-major (M, N) input:
    # Out[m] = min(|X[m, :]|). Each program handles BLOCK_M rows.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # +inf is the identity for min, so masked lanes (loaded as +inf)
    # never win the minimum.
    _min = tl.full([BLOCK_M, BLOCK_N], value=float("inf"), dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=float("inf")).to(tl.float32)
        _min = tl.minimum(tl.abs(a), _min)

    min = tl.min(_min, axis=1)
    out = min[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def min_norm_kernel_1(
    X, Mid, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 1 of the full-tensor (-inf)-norm: each program takes min(|x|)
    # over its BLOCK_SIZE slice and writes the partial minimum to Mid[pid].
    # Masked lanes load as +inf, the identity for min.
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=float("inf")).to(tl.float32)
    mid = tl.min(tl.abs(x))
    tl.store(Mid, mid)
@libentry()
@triton.jit
def min_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 2 of the full-tensor (-inf)-norm: one program reduces the
    # MID_SIZE partial minima into the final min and stores it in Out.
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=float("inf")).to(tl.float32)
    out = tl.min(mid)
    tl.store(Out, out)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise L0 "norm" over a row-major (M, N) input:
    # Out[m] = count of nonzero elements in X[m, :].
    # Fix: widen pid to int64 before the pid * N offset, matching the other
    # row-wise kernels in this file — without it the offset arithmetic can
    # overflow int32 for large M * N.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    # Masked lanes load as 0 and therefore do not count as nonzero.
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0).to(tl.float32)
        _sum += tl.where(a != 0, 1, 0)
    sum = tl.sum(_sum, axis=1)
    out = sum[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def l0_norm_kernel_1(
    X, Mid, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 1 of the full-tensor L0 "norm": each program counts the nonzero
    # elements in its BLOCK_SIZE slice and writes the partial count (as
    # fp32) to Mid[pid]. Masked lanes load as 0.0 and are not counted.
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    cnt = (x != 0).to(tl.float32)
    mid = tl.sum(cnt)
    tl.store(Mid, mid)
@libentry()
@triton.jit
def l0_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 2 of the full-tensor L0 "norm": one program sums the MID_SIZE
    # partial counts and stores the total nonzero count in Out.
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sum(mid)
    tl.store(Out, out)
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit(do_not_specialize=["ord"])
def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Row-wise general p-norm over a row-major (M, N) input:
    # Out[m] = (sum |X[m, :]| ** ord) ** (1 / ord).
    # `ord` is a runtime scalar (do_not_specialize avoids one compile per
    # value) and must be nonzero — the final 1 / ord divides by it.
    ord = ord.to(tl.float32)
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    # Accumulate |x| ** ord in fp32; masked lanes contribute 0.
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += pow(tl.abs(a), ord)
    sum = tl.sum(_sum, axis=1)
    out = pow(sum, 1 / ord)[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_1(
    X, Mid, ord, M, BLOCK_SIZE: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 1 of the full-tensor general p-norm (despite the "l1" name, it
    # handles any nonzero `ord`): each program sums |x| ** ord over its
    # BLOCK_SIZE slice and writes the partial to Mid[pid].
    ord = ord.to(tl.float32)
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(pow(tl.abs(x), ord))
    tl.store(Mid, mid)
@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_2(
    Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr, buffer_size_limit: tl.constexpr
):
    # Stage 2 of the full-tensor general p-norm: one program sums the
    # MID_SIZE partials and stores total ** (1 / ord). `ord` must be
    # nonzero (the ord == 0 case is routed to the l0 kernels instead).
    ord = ord.to(tl.float32)
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = pow(tl.sum(mid), 1 / ord)
    tl.store(Out, out)
def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    """Compute the vector norm of ``x`` (torch.linalg.vector_norm semantics).

    Args:
        x: input tensor; must be float16, float32, or bfloat16.
        ord: norm order. 2, +inf, -inf and 0 dispatch to dedicated kernels;
            any other nonzero value uses the generic p-norm kernels.
        dim: int or list of dims to reduce. ``None`` (or all dims) reduces
            the whole tensor.
        keepdim: keep reduced dims as size-1 dims when True.
        dtype: optional output dtype; defaults to ``x.dtype``.

    Returns:
        Tensor containing the norm, shaped according to ``dim``/``keepdim``.

    Raises:
        NotImplementedError: if the effective dtype is unsupported.
    """
    logger.debug("GEMS VECTOR NORM")
    if dtype is None:
        dtype = x.dtype
    elif not isinstance(dtype, torch.dtype):
        # Fix: torch.dtype(...) raises TypeError when handed an actual
        # torch.dtype instance (torch.dtype is not instantiable), so only
        # attempt conversion for dtype-like values.
        dtype = torch.dtype(dtype)
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")

    # Fix: normalize an int dim to a list. Previously dim=0 was falsy and
    # silently triggered a full reduction, and any other int crashed on
    # len(dim) below.
    if isinstance(dim, int):
        dim = [dim]

    with torch_device_fn.device(x.device):
        if (not dim) or len(dim) == x.ndim:
            # Full reduction: flatten and run the two-stage kernels
            # (per-block partials in stage 1, combined in stage 2).
            dim = list(range(x.ndim))
            shape = [1] * x.ndim
            x = dim_compress(x, dim)
            M = x.numel()
            cluster_num = 12
            # Cap BLOCK_SIZE so one tile stays within a 64 KiB buffer.
            BLOCK_SIZE = min(
                triton.next_power_of_2(triton.cdiv(M, cluster_num)),
                int(1024 * 64 / x.element_size()),
            )
            MID_SIZE = triton.cdiv(M, BLOCK_SIZE)
            BLOCK_MID = triton.next_power_of_2(MID_SIZE)

            mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)
            out = torch.empty(shape, dtype=dtype, device=x.device)
            if ord == 2:
                l2_norm_kernel_1[(MID_SIZE,)](
                    x, mid, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                l2_norm_kernel_2[(1,)](
                    mid, out, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
            elif ord == float("inf"):
                max_norm_kernel_1[(MID_SIZE,)](
                    x, mid, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                max_norm_kernel_2[(1,)](
                    mid, out, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
            elif ord == -float("inf"):
                min_norm_kernel_1[(MID_SIZE,)](
                    x, mid, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                min_norm_kernel_2[(1,)](
                    mid, out, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
            elif ord == 0:
                l0_norm_kernel_1[(MID_SIZE,)](
                    x, mid, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                l0_norm_kernel_2[(1,)](
                    mid, out, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
            else:
                # Generic nonzero ord: the "l1" kernels compute
                # (sum |x| ** ord) ** (1 / ord).
                l1_norm_kernel_1[(MID_SIZE,)](
                    x, mid, ord, M, BLOCK_SIZE, buffer_size_limit=2048
                )
                l1_norm_kernel_2[(1,)](
                    mid, out, ord, MID_SIZE, BLOCK_MID, buffer_size_limit=2048
                )
        else:
            # Partial reduction: move the reduced dims innermost/contiguous,
            # then run one row-wise kernel over the resulting (M, N) view.
            shape = list(x.shape)
            dim = [d % x.ndim for d in dim]
            x = dim_compress(x, dim)
            N = 1
            for i in dim:
                N *= shape[i]
                shape[i] = 1
            M = x.numel() // N
            out = torch.empty(shape, dtype=dtype, device=x.device)
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            if ord == 2:
                l2_norm_kernel[grid](x, out, M, N)
            elif ord == float("inf"):
                max_norm_kernel[grid](x, out, M, N)
            elif ord == -float("inf"):
                min_norm_kernel[grid](x, out, M, N)
            elif ord == 0:
                l0_norm_kernel[grid](x, out, M, N)
            else:
                # NOTE(review): isCloseUnrollControl looks like a
                # kunlunxin-specific launch option — confirm with backend docs.
                v_norm_kernel[grid](x, out, M, N, ord, isCloseUnrollControl=True)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out