Coverage for src/flag_gems/experimental_ops/absolute.py: 0%
54 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _absolute_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise ``|x|`` for real (integer or floating-point) inputs.

    Each program instance handles one contiguous ``BLOCK_SIZE`` slice of the
    flat buffer; lanes at or beyond ``n_elements`` are masked off for both the
    load and the store.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # tl.abs clears the sign bit for floats, so -0.0 -> +0.0 (matching
    # torch.abs). A `where(x < 0, -x, x)` select would keep -0.0, because
    # -0.0 < 0 is false. For signed ints it negates; unsigned ints pass through.
    y = tl.abs(x)
    tl.store(out_ptr + offsets, y, mask=mask)
@triton.jit
def _absolute_complex_kernel(ri_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Compute ``|z|`` for complex values stored as interleaved (real, imag) pairs.

    ``ri_ptr`` points at a contiguous real buffer laid out as
    ``[re0, im0, re1, im1, ...]``, i.e. what ``torch.view_as_real`` yields for
    a contiguous complex tensor. ``out_ptr`` receives ``n_elements`` magnitudes.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Two scalars per complex element. Index in 64-bit so 2 * offset cannot
    # wrap int32 for large inputs.
    base = offsets.to(tl.int64) * 2
    re = tl.load(ri_ptr + base, mask=mask)
    im = tl.load(ri_ptr + base + 1, mask=mask)

    # Overflow/underflow-safe hypot. The naive sqrt(re*re + im*im) overflows to
    # inf once a component exceeds sqrt(max finite): ~1.8e19 in float32 and
    # only ~256 in float16. Scaling by the larger magnitude keeps the squared
    # term in [0, 1].
    abs_re = tl.abs(re)
    abs_im = tl.abs(im)
    big = tl.maximum(abs_re, abs_im)
    small = tl.minimum(abs_re, abs_im)
    ratio = small / tl.where(big == 0, 1.0, big)
    y = big * tl.sqrt(1.0 + ratio * ratio)
    # hypot special cases, applied explicitly because min/max NaN handling is
    # not relied upon: NaN if either part is NaN...
    y = tl.where((re != re) | (im != im), abs_re + abs_im, y)
    # ...except that an infinite part always gives +inf (this also repairs
    # the inf/inf ratio above).
    y = tl.where((abs_re == float("inf")) | (abs_im == float("inf")), float("inf"), y)
    tl.store(out_ptr + offsets, y, mask=mask)
def absolute(input: torch.Tensor):
    """Return the element-wise absolute value of ``input`` (like ``torch.abs``).

    Real inputs produce a tensor of the same dtype. Complex inputs produce the
    magnitude ``|z|`` in the matching real dtype (``x.real.dtype``, e.g.
    float32 for complex64). The result is a freshly allocated contiguous
    tensor with ``input``'s shape, on ``input``'s device.
    """
    x = input.contiguous()
    n_elements = x.numel()
    if x.is_complex():
        out = torch.empty(x.shape, dtype=x.real.dtype, device=x.device)
    else:
        out = torch.empty_like(x)
    # Nothing to compute; also avoids launching a kernel with an empty grid.
    if n_elements == 0:
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    if x.is_complex():
        # view_as_real rejects tensors carrying a lazy conjugate bit (e.g. the
        # result of z.conj()), so materialize it first. No-op otherwise.
        ri = torch.view_as_real(x.resolve_conj()).contiguous()
        _absolute_complex_kernel[grid](ri, out, n_elements, BLOCK_SIZE=1024)
    else:
        _absolute_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out
def absolute_out(input: torch.Tensor, out: torch.Tensor):
    """Write the element-wise absolute value of ``input`` into ``out``.

    ``out`` must meet all of the following:
    - be contiguous;
    - be on the same device as ``input``;
    - have the same shape as ``input``;
    - have dtype ``input.dtype`` for real inputs, or ``input.real.dtype``
      (e.g. float32 for complex64) for complex inputs.

    Returns:
        ``out``, filled in place.

    Raises:
        RuntimeError: if ``out`` does not satisfy the requirements above.
    """
    x = input.contiguous()
    n_elements = x.numel()
    if x.is_complex():
        expected_dtype = x.real.dtype
        dtype_msg = "out dtype must be the real dtype of the complex input"
    else:
        expected_dtype = x.dtype
        dtype_msg = "out dtype must match input dtype"

    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, and a mismatched `out` would then be written out of bounds.
    if out.dtype != expected_dtype:
        raise RuntimeError(dtype_msg)
    if out.shape != x.shape:
        raise RuntimeError("out must have the same shape as input")
    if not out.is_contiguous():
        raise RuntimeError("out must be contiguous")
    if out.device != x.device:
        raise RuntimeError("out must be on the same device as input")

    # Nothing to compute; also avoids launching a kernel with an empty grid.
    if n_elements == 0:
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    if x.is_complex():
        # view_as_real rejects tensors carrying a lazy conjugate bit (e.g. the
        # result of z.conj()), so materialize it first. No-op otherwise.
        ri = torch.view_as_real(x.resolve_conj()).contiguous()
        _absolute_complex_kernel[grid](ri, out, n_elements, BLOCK_SIZE=1024)
    else:
        _absolute_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out