Coverage for src/flag_gems/runtime/backend/_ascend/ops/vector_norm.py: 0%
267 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry, tl_extra_shim
11from flag_gems.utils import triton_lang_extension as tle
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

# Prefer the Ascend libdevice `pow` when torch_npu is importable; otherwise
# fall back to the generic shim implementation. A failure here can be an
# ImportError (torch_npu absent) or an AttributeError (no ascend libdevice),
# both of which are Exception subclasses — a bare `except:` was broader than
# needed and would also have swallowed KeyboardInterrupt/SystemExit.
try:
    import torch_npu  # noqa: F401

    pow = tl.extra.ascend.libdevice.pow
except Exception:
    pow = tl_extra_shim.pow
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L2 norm: Out[m] = sqrt(sum(X[m, :] ** 2)).

    X is treated as a contiguous (M, N) matrix; each program reduces
    BLOCK_M rows, scanning the N columns in BLOCK_N-wide chunks.
    """
    # Column vector of row indices for this program; cast to int64 so the
    # pid * N pointer offset cannot overflow 32-bit arithmetic.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Accumulate squared values in float32 regardless of the input dtype.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # combined (BLOCK_M, BLOCK_N) bounds mask

        # Masked lanes load 0.0 and contribute nothing to the sum.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += a * a
    sum = tl.sum(_sum, axis=1)

    out = tl.sqrt(sum)[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr):
    """Stage 1 of the full-reduce L2 norm: partial sums of squares.

    Each program reduces one BLOCK_SIZE-sized slice of the flattened input X
    (M elements total), iterating in BLOCK_SIZE_SUB chunks, and writes its
    float32 partial sum of squares to Mid[pid].
    """
    pid = tl.program_id(0).to(tl.int64)

    total_sum = 0.0
    for off in range(0, BLOCK_SIZE, BLOCK_SIZE_SUB):
        offsets = pid * BLOCK_SIZE + off + tl.arange(0, BLOCK_SIZE_SUB)
        mask = offsets < M  # guard the tail of the flattened input
        x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32)
        total_sum += tl.sum(x * x)

    tl.store(Mid + pid, total_sum)
@libentry()
@triton.jit
def l2_norm_kernel_2(
    Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr, BLOCK_MID_SUB: tl.constexpr
):
    """Stage 2 of the full-reduce L2 norm: combine partial sums and sqrt.

    Launched with grid (1,): the single program sums the MID_SIZE partial
    sums in Mid in BLOCK_MID_SUB chunks and stores sqrt(total) to Out.
    BLOCK_MID is unused in the body; kept for the launch signature.
    """
    pid = tl.program_id(0).to(tl.int64)

    total_sum = 0.0
    for off in range(0, MID_SIZE, BLOCK_MID_SUB):
        # NOTE(review): `pid * MID_SIZE` is only correct because pid == 0 for
        # the (1,)-grid launch; a multi-program launch would read out of range.
        offsets = pid * MID_SIZE + off + tl.arange(0, BLOCK_MID_SUB)
        mask = offsets < MID_SIZE
        x = tl.load(Mid + offsets, mask=mask, other=0.0).to(tl.float32)
        total_sum += tl.sum(x)
    out = tl.sqrt(total_sum)
    tl.store(Out, out)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise inf-norm: Out[m] = max(abs(X[m, :])) over a (M, N) view."""
    # Column vector of row indices; int64 to avoid pointer-offset overflow.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # 0.0 is the identity for a max over absolute values.
    _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # combined bounds mask

        # Masked lanes read 0.0, which cannot win the max over |x|.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _max = tl.maximum(tl.abs(a), _max)

    max = tl.max(_max, axis=1)
    out = max[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def max_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the full-reduce inf-norm: per-block max of |x|.

    Each program reduces one BLOCK_SIZE slice of the flattened input
    (M elements) and writes its float32 partial max to Mid[pid].
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # Out-of-range lanes read 0.0, which cannot win the max over |x|.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.max(tl.abs(x))
    tl.store(Mid, mid)
@libentry()
@triton.jit
def max_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full-reduce inf-norm: max over the MID_SIZE partials.

    Single-program kernel; BLOCK_MID must be >= MID_SIZE so one load covers
    the whole Mid buffer.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # Padding lanes read 0.0; partial maxima are non-negative, so this is safe.
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.max(mid)
    tl.store(Out, out)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise -inf-norm: Out[m] = min(abs(X[m, :])) over a (M, N) view."""
    # Column vector of row indices; int64 to avoid pointer-offset overflow.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # +inf is the identity for min.
    _min = tl.full([BLOCK_M, BLOCK_N], value=float("inf"), dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # combined bounds mask

        # Masked lanes read +inf so they never win the min.
        a = tl.load(X + cols, mask, other=float("inf")).to(tl.float32)
        _min = tl.minimum(tl.abs(a), _min)

    min = tl.min(_min, axis=1)
    out = min[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def min_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the full-reduce -inf-norm: per-block min of |x|.

    Each program reduces one BLOCK_SIZE slice of the flattened input
    (M elements) and writes its float32 partial min to Mid[pid].
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # Out-of-range lanes read +inf so they never win the min.
    x = tl.load(X, mask=mask, other=float("inf")).to(tl.float32)
    mid = tl.min(tl.abs(x))
    tl.store(Mid, mid)
@libentry()
@triton.jit
def min_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full-reduce -inf-norm: min over the MID_SIZE partials.

    Single-program kernel; BLOCK_MID must be >= MID_SIZE so one load covers
    the whole Mid buffer.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # Padding lanes read +inf so they never win the min.
    mid = tl.load(Mid, mask=mask, other=float("inf")).to(tl.float32)
    out = tl.min(mid)
    tl.store(Out, out)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit
def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise L0 "norm": Out[m] = number of non-zero elements in X[m, :].

    X is treated as a contiguous (M, N) matrix; each program handles BLOCK_M
    rows, scanning the N columns in BLOCK_N-wide chunks.
    """
    # FIX: cast the program id to int64 before computing the pid * N pointer
    # offset, matching every other row-norm kernel in this file; without it
    # the offset can overflow 32-bit arithmetic on large inputs.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Accumulate the non-zero indicator in float32.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # combined bounds mask

        # Masked lanes load 0 and therefore do not contribute to the count.
        a = tl.load(X + cols, mask, other=0).to(tl.float32)
        _sum += tl.where(a != 0, 1, 0)
    sum = tl.sum(_sum, axis=1)
    out = sum[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit
def l0_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the full-reduce L0 "norm": per-block non-zero count.

    Each program counts the non-zero elements of one BLOCK_SIZE slice of the
    flattened input (M elements) and writes the float32 count to Mid[pid].
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # Out-of-range lanes load 0.0 and are not counted.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    cnt = (x != 0).to(tl.float32)
    mid = tl.sum(cnt)
    tl.store(Mid, mid)
@libentry()
@triton.jit
def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full-reduce L0 "norm": sum the MID_SIZE partial counts.

    Single-program kernel; BLOCK_MID must be >= MID_SIZE so one load covers
    the whole Mid buffer.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # Padding lanes read 0.0 and do not change the total.
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sum(mid)
    tl.store(Out, out)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
@triton.jit(do_not_specialize=["ord"])
def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Row-wise generic p-norm: Out[m] = (sum(|X[m, :]| ** ord)) ** (1 / ord).

    `ord` is marked do_not_specialize so a single compilation serves every p.
    """
    # Column vector of row indices; int64 to avoid pointer-offset overflow.
    pid = tle.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Out = Out + pid
    row_mask = pid < M

    # Accumulate |x| ** ord in float32.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask  # combined bounds mask

        # Masked lanes load 0.0; 0 ** ord contributes nothing for ord > 0.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _sum += pow(tl.abs(a), ord)
    sum = tl.sum(_sum, axis=1)
    out = pow(sum, 1 / ord)[:, None]
    tl.store(Out, out, row_mask)
@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the full-reduce generic p-norm (despite the "l1" name):
    Mid[pid] = sum(|x| ** ord) over this program's BLOCK_SIZE slice of the
    flattened input (M elements). `ord` is not specialized on.
    """
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    # Out-of-range lanes load 0.0; 0 ** ord contributes nothing for ord > 0.
    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(pow(tl.abs(x), ord))
    tl.store(Mid, mid)
@libentry()
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_2(Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full-reduce generic p-norm:
    Out = (sum(Mid)) ** (1 / ord).

    Single-program kernel; BLOCK_MID must be >= MID_SIZE so one load covers
    the whole Mid buffer.
    """
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    # Padding lanes read 0.0 and do not change the total.
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = pow(tl.sum(mid), 1 / ord)
    tl.store(Out, out)
def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    """Compute the vector norm of ``x`` (torch.linalg.vector_norm semantics).

    Args:
        x: input tensor.
        ord: norm order. 2, +inf, -inf and 0 dispatch to dedicated kernels;
            any other value uses the generic p-norm kernels.
        dim: int, sequence of ints, or None/empty for a full reduction.
        keepdim: keep the reduced dimensions with size 1.
        dtype: optional torch.dtype for the computation/output; defaults to
            ``x.dtype``.

    Returns:
        Tensor holding the norm, shaped according to ``dim``/``keepdim``.

    Raises:
        TypeError: if ``dtype`` is not a ``torch.dtype``.
        NotImplementedError: if the resolved dtype is not fp16/fp32/bf16.
    """
    logger.debug("GEMS_ASCEND VECTOR NORM")
    if dtype is not None:
        # BUGFIX: the previous code called torch.dtype(dtype), but torch.dtype
        # cannot be instantiated, so every explicit dtype argument crashed.
        # Accept torch.dtype values directly and reject anything else.
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"dtype must be a torch.dtype, got {type(dtype).__name__}")
    else:
        dtype = x.dtype
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")

    # BUGFIX: accept a bare int dim. Previously dim=1 crashed on len(dim) and
    # dim=0 was falsy, silently triggering a full reduction.
    if isinstance(dim, int):
        dim = [dim]

    with torch_device_fn.device(x.device):
        if (not dim) or len(dim) == x.ndim:
            # Full reduction: flatten and run the two-stage 1D kernels.
            dim = list(range(x.ndim))
            shape = [1] * x.ndim
            x = dim_compress(x, dim)
            M = x.numel()
            # Size stage 1 so both stages do roughly sqrt(M) work, capped so
            # a single program's block stays reasonable.
            MAX_BLOCK_SIZE = 32768
            BLOCK_SIZE = min(
                triton.next_power_of_2(math.ceil(math.sqrt(M))), MAX_BLOCK_SIZE
            )
            MID_SIZE = triton.cdiv(M, BLOCK_SIZE)
            BLOCK_MID = triton.next_power_of_2(MID_SIZE)
            # Chunk width for the stage-2 loop over the partial results.
            BLOCK_MID_SUB = 512 if BLOCK_MID >= 512 else 1
            mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)
            out = torch.empty(shape, dtype=dtype, device=x.device)
            if ord == 2:
                l2_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE, BLOCK_MID_SUB)
                l2_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID, BLOCK_MID_SUB)
            elif ord == float("inf"):
                max_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                max_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == -float("inf"):
                min_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                min_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            elif ord == 0:
                l0_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
                l0_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
            else:
                # Generic p-norm (the "l1" kernels take ord as a runtime arg).
                l1_norm_kernel_1[(MID_SIZE,)](x, mid, ord, M, BLOCK_SIZE)
                l1_norm_kernel_2[(1,)](mid, out, ord, MID_SIZE, BLOCK_MID)
        else:
            # Partial reduction: move the reduced dims innermost so x can be
            # viewed as a contiguous (M, N) matrix with N the reduced extent.
            shape = list(x.shape)
            dim = [d % x.ndim for d in dim]  # normalize negative dims
            x = dim_compress(x, dim)
            N = 1
            for i in dim:
                N *= shape[i]
                shape[i] = 1
            M = x.numel() // N
            out = torch.empty(shape, dtype=dtype, device=x.device)
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            if ord == 2:
                l2_norm_kernel[grid](x, out, M, N)
            elif ord == float("inf"):
                max_norm_kernel[grid](x, out, M, N)
            elif ord == -float("inf"):
                min_norm_kernel[grid](x, out, M, N)
            elif ord == 0:
                l0_norm_kernel[grid](x, out, M, N)
            else:
                v_norm_kernel[grid](x, out, M, N, ord)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out