# Coverage for src/flag_gems/runtime/backend/_cambricon/ops/vector_norm.py: 0%
# 308 statements — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry, tl_extra_shim
11from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op, prune_reduce_config
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
14pow = tl_extra_shim.pow
17@libentry()
18@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
19@triton.jit
20def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
21 # Map the program id to the row of X it should compute.
22 num_prog = tl.num_programs(0)
23 task_num = tl.cdiv(M, BLOCK_M)
24 iter_num = tl.cdiv(task_num, num_prog)
25 if task_num % num_prog != 0:
26 iter_num = iter_num + 1
27 for i in range(0, iter_num):
28 pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
29 :, None
30 ]
31 X_ptr = X + pid * N
32 Out_ptr = Out + pid
33 row_mask = pid < M
35 _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
36 for off in range(0, N, BLOCK_N):
37 cols = off + tl.arange(0, BLOCK_N)[None, :]
38 col_mask = cols < N
39 mask = row_mask and col_mask
41 a = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
42 _sum += a * a
43 sum = tl.sum(_sum, axis=1)
45 out = tl.sqrt(sum)[:, None]
46 tl.store(Out_ptr, out, row_mask)
49@libentry()
50@triton.autotune(
51 configs=cfggen_reduce_op(),
52 key=["M"],
53 prune_configs_by={"early_config_prune": prune_reduce_config},
54 reset_to_zero=["Out"],
55)
56@triton.heuristics(
57 values={
58 "ONE_TILE_PER_CTA": lambda args: args["M"]
59 <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
60 },
61)
62@triton.jit
63def l2_norm_kernel_1(
64 X, Out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
65):
66 pid = tl.program_id(0)
67 block_start = pid * BLOCK_SIZE
69 mid = 0.0
70 if ONE_TILE_PER_CTA:
71 offsets = block_start + tl.arange(0, BLOCK_SIZE)
72 mask = offsets < M
73 x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
74 mid = tl.sum(x * x)
75 else:
76 _tmp = tl.zeros([BLOCK_SIZE], tl.float32)
77 num_jobs = tl.num_programs(axis=0)
78 step = num_jobs * BLOCK_SIZE
79 for block_start_offset in range(block_start, M, step):
80 offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
81 mask = offsets < M
82 x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
83 _tmp = _tmp + x * x
84 mid = tl.sum(_tmp)
86 tl.atomic_add(Out, mid.to(tl.float32))
89@libentry()
90@triton.jit
91def l2_norm_kernel_2(
92 Out,
93):
94 out = tl.load(Out)
95 out = tl.sqrt(out)
96 tl.store(Out, out)
99@libentry()
100@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
101@triton.jit
102def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
103 # Map the program id to the row of X it should compute.
104 num_prog = tl.num_programs(0)
105 task_num = tl.cdiv(M, BLOCK_M)
106 iter_num = tl.cdiv(task_num, num_prog)
107 if task_num % num_prog != 0:
108 iter_num = iter_num + 1
109 for i in range(0, iter_num):
110 pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
111 :, None
112 ]
113 X_ptr = X + pid * N
114 Out_ptr = Out + pid
115 row_mask = pid < M
117 _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
118 for off in range(0, N, BLOCK_N):
119 cols = off + tl.arange(0, BLOCK_N)[None, :]
120 col_mask = cols < N
121 mask = row_mask and col_mask
123 a = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
124 _max = tl.maximum(tl.abs(a), _max)
126 max = tl.max(_max, axis=1)
127 out = max[:, None]
128 tl.store(Out_ptr, out, row_mask)
131@libentry()
132@triton.autotune(
133 configs=cfggen_reduce_op(),
134 key=["M"],
135 prune_configs_by={"early_config_prune": prune_reduce_config},
136)
137@triton.heuristics(
138 values={
139 "ONE_TILE_PER_CTA": lambda args: args["M"]
140 <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
141 },
142)
143@triton.jit
144def max_norm_kernel_1(
145 X, Out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
146):
147 pid = tl.program_id(0)
148 block_start = pid * BLOCK_SIZE
150 mid = 0.0
151 if ONE_TILE_PER_CTA:
152 offsets = block_start + tl.arange(0, BLOCK_SIZE)
153 mask = offsets < M
154 x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
155 mid = tl.max(tl.abs(x))
156 else:
157 _tmp = tl.zeros([BLOCK_SIZE], tl.float32)
158 num_jobs = tl.num_programs(axis=0)
159 step = num_jobs * BLOCK_SIZE
160 for block_start_offset in range(block_start, M, step):
161 offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
162 mask = offsets < M
163 x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
164 _x = tl.abs(x)
165 _tmp = tl.where(_tmp > _x, _tmp, _x)
166 mid = tl.max(_tmp)
168 tl.atomic_max(Out, mid.to(tl.float32))
171@libentry()
172@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
173@triton.jit
174def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
175 # Map the program id to the row of X it should compute.
176 num_prog = tl.num_programs(0)
177 task_num = tl.cdiv(M, BLOCK_M)
178 iter_num = tl.cdiv(task_num, num_prog)
179 if task_num % num_prog != 0:
180 iter_num = iter_num + 1
181 for i in range(0, iter_num):
182 pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
183 :, None
184 ]
185 X_ptr = X + pid * N
186 Out_ptr = Out + pid
187 row_mask = pid < M
189 _min = tl.full([BLOCK_M, BLOCK_N], value=float("inf"), dtype=tl.float32)
190 for off in range(0, N, BLOCK_N):
191 cols = off + tl.arange(0, BLOCK_N)[None, :]
192 col_mask = cols < N
193 mask = row_mask and col_mask
195 a = tl.load(X_ptr + cols, mask, other=float("inf")).to(tl.float32)
196 _min = tl.minimum(tl.abs(a), _min)
198 min = tl.min(_min, axis=1)
199 out = min[:, None]
200 tl.store(Out_ptr, out, row_mask)
@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    },
)
@triton.jit
def min_norm_kernel_1(
    X, Out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Full-reduction -inf norm: Out = min(|x|) over all M elements.

    Each program reduces its share to a scalar partial minimum, then
    combines via tl.atomic_min. The caller must pre-fill Out with a large
    value (float32 max) so atomic_min can converge to the true minimum.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    if ONE_TILE_PER_CTA:
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < M
        # Masked lanes load +inf — neutral for a min reduction.
        x = tl.load(X + offsets, mask, other=float("inf")).to(tl.float32)
        mid = tl.min(tl.abs(x))
    else:
        # BUG FIX: the min accumulator must start at +inf, not zero.
        # With a zero init, tl.where(_tmp < _x, _tmp, _x) could never rise
        # above 0 (|x| >= 0), so the multi-tile path always returned 0.
        # This now matches the +inf identity used by min_norm_kernel.
        _tmp = tl.full([BLOCK_SIZE], value=float("inf"), dtype=tl.float32)
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        for block_start_offset in range(block_start, M, step):
            offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offsets < M
            x = tl.load(X + offsets, mask, other=float("inf")).to(tl.float32)
            _x = tl.abs(x)
            # Element-wise min: keep the smaller of accumulator and |x|.
            _tmp = tl.where(_tmp < _x, _tmp, _x)
        mid = tl.min(_tmp)

    tl.atomic_min(Out, mid.to(tl.float32))
242@libentry()
243@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
244@triton.jit
245def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
246 # Map the program id to the row of X it should compute.
247 num_prog = tl.num_programs(0)
248 task_num = tl.cdiv(M, BLOCK_M)
249 iter_num = tl.cdiv(task_num, num_prog)
250 if task_num % num_prog != 0:
251 iter_num = iter_num + 1
252 for i in range(0, iter_num):
253 pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
254 :, None
255 ]
256 X_ptr = X + pid * N
257 Out_ptr = Out + pid
258 row_mask = pid < M
260 _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
261 for off in range(0, N, BLOCK_N):
262 cols = off + tl.arange(0, BLOCK_N)[None, :]
263 col_mask = cols < N
264 mask = row_mask and col_mask
266 a = tl.load(X_ptr + cols, mask, other=0).to(tl.float32)
267 _sum += tl.where(a != 0, 1, 0)
268 sum = tl.sum(_sum, axis=1)
269 out = sum[:, None]
270 tl.store(Out_ptr, out, row_mask)
273@libentry()
274@triton.autotune(
275 configs=cfggen_reduce_op(),
276 key=["M"],
277 prune_configs_by={"early_config_prune": prune_reduce_config},
278 reset_to_zero=["Out"],
279)
280@triton.heuristics(
281 values={
282 "ONE_TILE_PER_CTA": lambda args: args["M"]
283 <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
284 },
285)
286@triton.jit
287def l0_norm_kernel_1(
288 X, Out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
289):
290 pid = tl.program_id(0)
291 block_start = pid * BLOCK_SIZE
293 if ONE_TILE_PER_CTA:
294 offsets = block_start + tl.arange(0, BLOCK_SIZE)
295 mask = offsets < M
296 x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
297 mid = tl.sum((x != 0).to(tl.float32))
298 else:
299 _tmp = tl.zeros([BLOCK_SIZE], tl.float32)
300 num_jobs = tl.num_programs(axis=0)
301 step = num_jobs * BLOCK_SIZE
302 for block_start_offset in range(block_start, M, step):
303 offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
304 mask = offsets < M
305 x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
306 _tmp = _tmp + (x != 0).to(tl.float32)
307 mid = tl.sum(_tmp)
309 tl.atomic_add(Out, mid.to(tl.float32))
312@libentry()
313@triton.autotune(configs=runtime.get_tuned_config("vector_norm"), key=["M", "N"])
314@triton.jit(do_not_specialize=["ord"])
315def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
316 # Map the program id to the row of X it should compute.
317 num_prog = tl.num_programs(0)
318 task_num = tl.cdiv(M, BLOCK_M)
319 iter_num = tl.cdiv(task_num, num_prog)
321 for i in range(0, iter_num):
322 pid = (i * num_prog + tl.program_id(0)) * BLOCK_M + tl.arange(0, BLOCK_M)[
323 :, None
324 ]
325 X_ptr = X + pid * N
326 Out_ptr = Out + pid
327 row_mask = pid < M
329 _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
330 for off in range(0, N, BLOCK_N):
331 cols = off + tl.arange(0, BLOCK_N)[None, :]
332 col_mask = cols < N
333 mask = row_mask and col_mask
335 a = tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
336 _sum += tl.extra.mlu.libdevice.pow(tl.abs(a), ord)
337 sum = tl.sum(_sum, axis=1)
338 out = tl.extra.mlu.libdevice.pow(sum, 1 / ord)[:, None]
339 tl.store(Out_ptr, out, row_mask)
342@libentry()
343@triton.autotune(
344 configs=cfggen_reduce_op(),
345 key=["M"],
346 prune_configs_by={"early_config_prune": prune_reduce_config},
347 reset_to_zero=["Out"],
348)
349@triton.heuristics(
350 values={
351 "ONE_TILE_PER_CTA": lambda args: args["M"]
352 <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
353 },
354)
355@triton.jit(do_not_specialize=["ord"])
356def l1_norm_kernel_1(
357 X, Out, M, ord, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
358):
359 pid = tl.program_id(0)
360 block_start = pid * BLOCK_SIZE
362 mid = 0.0
363 if ONE_TILE_PER_CTA:
364 offsets = block_start + tl.arange(0, BLOCK_SIZE)
365 mask = offsets < M
366 x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
367 mid = tl.sum(pow(tl.abs(x), ord))
368 else:
369 _tmp = tl.zeros([BLOCK_SIZE], tl.float32)
370 num_jobs = tl.num_programs(axis=0)
371 step = num_jobs * BLOCK_SIZE
372 for block_start_offset in range(block_start, M, step):
373 offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
374 mask = offsets < M
375 x = tl.load(X + offsets, mask, other=0.0).to(tl.float32)
376 _tmp = _tmp + pow(tl.abs(x), ord)
377 mid = tl.sum(_tmp)
379 tl.atomic_add(Out, mid.to(tl.float32))
382@libentry()
383@triton.jit(do_not_specialize=["ord"])
384def l1_norm_kernel_2(
385 Out,
386 ord,
387):
388 out = tl.load(Out)
389 out = pow(out, 1 / ord)
390 tl.store(Out, out)
def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    """Compute the vector norm of ``x`` (torch.linalg.vector_norm semantics).

    Args:
        x: Input tensor (float16/float32/bfloat16 after dtype resolution).
        ord: Norm order — 2, 0, +/-inf, or any other number (general p-norm).
        dim: Dimension(s) to reduce. None, an empty/full list, or a list
            covering all dims triggers a full reduction to a scalar.
            Accepts a single int as well as a list/tuple of ints.
        keepdim: Keep reduced dimensions with size 1 when True.
        dtype: Optional result dtype (torch.dtype or its string name).

    Returns:
        Tensor holding the requested norm.

    Raises:
        NotImplementedError: if the resolved dtype is not fp16/fp32/bf16.
    """
    logger.debug("GEMS_CAMBRICON VECTOR NORM")
    # Resolve dtype: allow string names; non-dtype objects silently fall
    # back to float32 (preserved behavior — NOTE(review): arguably this
    # should raise, but callers may rely on the fallback).
    if dtype is not None:
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        elif not isinstance(dtype, torch.dtype):
            dtype = torch.float32
    else:
        dtype = x.dtype
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")

    # Generalization: accept a single int dim like torch does. Previously
    # dim=0 was silently treated as a full reduction (falsy) and any other
    # int crashed on len(dim).
    if isinstance(dim, int):
        dim = [dim]

    with torch_device_fn.device(x.device):
        if (not dim) or len(dim) == x.ndim:
            # Full reduction over every element. Result keeps an all-ones
            # shape until the final squeeze below.
            dim = list(range(x.ndim))
            shape = [1] * x.ndim
            x = dim_compress(x, dim)
            M = x.numel()

            grid = lambda meta: (
                min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            out = torch.zeros(shape, dtype=torch.float, device=x.device)
            if ord == 2:
                # Two-stage: accumulate sum of squares, then sqrt in place.
                l2_norm_kernel_1[grid](x, out, M)
                l2_norm_kernel_2[(1,)](out)
            elif ord == float("inf"):
                max_norm_kernel_1[grid](x, out, M)
            elif ord == -float("inf"):
                # atomic_min needs Out pre-filled with a large value.
                out = torch.full(
                    shape,
                    fill_value=torch.finfo(torch.float32).max,
                    dtype=torch.float,
                    device=x.device,
                )
                min_norm_kernel_1[grid](x, out, M)
            elif ord == 0:
                l0_norm_kernel_1[grid](x, out, M)
            else:
                # General p-norm: accumulate |x|^ord, then take the 1/ord root.
                l1_norm_kernel_1[grid](x, out, M, ord)
                l1_norm_kernel_2[(1,)](
                    out,
                    ord,
                )
            out = out.to(dtype)
        else:
            # Partial reduction: move the reduced dims innermost so each
            # output element reduces a contiguous row of length N.
            shape = list(x.shape)
            dim = [d % x.ndim for d in dim]  # normalize negative dims
            x = dim_compress(x, dim)
            N = 1
            for i in dim:
                N *= shape[i]
                shape[i] = 1
            M = x.numel() // N
            out = torch.empty(shape, dtype=dtype, device=x.device)
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            if ord == 2:
                l2_norm_kernel[grid](x, out, M, N)
            elif ord == float("inf"):
                max_norm_kernel[grid](x, out, M, N)
            elif ord == -float("inf"):
                min_norm_kernel[grid](x, out, M, N)
            elif ord == 0:
                l0_norm_kernel[grid](x, out, M, N)
            else:
                v_norm_kernel[grid](x, out, M, N, ord)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out