# Coverage for src/flag_gems/experimental_ops/abs.py: 0% (101 statements)
# coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _abs_kernel_real(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise absolute value for real (integer or floating) inputs."""
    block = tl.program_id(axis=0)
    idx = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements
    vals = tl.load(in_ptr + idx, mask=valid)
    # Works for both signed integers and floats: keep non-negative values,
    # negate the rest.
    tl.store(out_ptr + idx, tl.where(vals >= 0, vals, -vals), mask=valid)
@triton.jit
def _abs_kernel_complex(rr_ptr, out_ptr, n_complex, BLOCK_SIZE: tl.constexpr):
    """Magnitude of interleaved complex scalars: out[i] = sqrt(re_i^2 + im_i^2).

    ``rr_ptr`` addresses the complex storage as real scalars laid out
    [re0, im0, re1, im1, ...]; ``n_complex`` counts complex elements.
    """
    block = tl.program_id(axis=0)
    idx = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # complex-element index
    valid = idx < n_complex
    pair = idx * 2  # offset of the real part inside the interleaved buffer
    re = tl.load(rr_ptr + pair, mask=valid)
    im = tl.load(rr_ptr + pair + 1, mask=valid)
    tl.store(out_ptr + idx, tl.sqrt(re * re + im * im), mask=valid)
32def _ensure_cuda_tensor(x: torch.Tensor):
33 if not isinstance(x, torch.Tensor):
34 raise TypeError("Input must be a torch.Tensor")
35 if x.device.type != "cuda":
36 raise ValueError("Tensor must be on CUDA device")
37 return x
40def _complex_abs_out_dtype(dtype: torch.dtype) -> torch.dtype:
41 if dtype == torch.complex64:
42 return torch.float32
43 if dtype == torch.complex128:
44 return torch.float64
45 # Optional support if complex32 exists
46 if hasattr(torch, "complex32") and dtype == getattr(torch, "complex32"):
47 return torch.float16
48 raise NotImplementedError(f"Unsupported complex dtype for abs: {dtype}")
def _launch_abs_real(inp: torch.Tensor, out: torch.Tensor):
    """Launch the real-valued abs kernel over every element of *out*.

    Both tensors are expected to be contiguous; the kernel indexes storage
    linearly.
    """
    total = out.numel()
    if total == 0:
        # Nothing to compute; a zero-sized grid launch is invalid anyway.
        return
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(total, BLOCK_SIZE),)
    _abs_kernel_real[grid](inp, out, total, BLOCK_SIZE=BLOCK_SIZE)
def _launch_abs_complex(inp: torch.Tensor, out: torch.Tensor):
    """Launch the complex-magnitude kernel.

    *inp* is a contiguous complex tensor and *out* a contiguous real tensor
    of matching shape. The complex storage is reinterpreted as interleaved
    real scalars via ``Tensor.view(real_dtype)``.
    """
    total = inp.numel()
    if total == 0:
        return
    view_dtype = {
        torch.complex64: torch.float32,
        torch.complex128: torch.float64,
    }
    half_complex = getattr(torch, "complex32", None)
    if half_complex is not None:
        view_dtype[half_complex] = torch.float16
    real_dtype = view_dtype.get(inp.dtype)
    if real_dtype is None:
        raise NotImplementedError(f"Unsupported complex dtype for abs: {inp.dtype}")
    rr = inp.view(real_dtype)  # zero-copy reinterpretation: [re0, im0, re1, ...]
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(total, BLOCK_SIZE),)
    _abs_kernel_complex[grid](rr, out, total, BLOCK_SIZE=BLOCK_SIZE)
def abs(x: torch.Tensor):
    """Elementwise absolute value of a CUDA tensor.

    For complex inputs, returns the magnitude in the corresponding real
    dtype (complex64 -> float32, etc.). For real inputs, returns a tensor
    of the same dtype. The result is always contiguous.

    Raises:
        TypeError: if *x* is not a torch.Tensor.
        ValueError: if *x* is not on a CUDA device.
        NotImplementedError: for unsupported complex dtypes.
    """
    x = _ensure_cuda_tensor(x)
    if x.is_complex():
        out_dtype = _complex_abs_out_dtype(x.dtype)
        # torch.empty always allocates contiguously, matching the kernel's
        # linear-index writes.
        out = torch.empty(x.shape, dtype=out_dtype, device=x.device)
        _launch_abs_complex(x.contiguous(), out)
        return out
    # BUG FIX: empty_like defaults to preserve_format, so a dense but
    # non-contiguous input (e.g. a transpose) previously produced an `out`
    # with non-contiguous strides while the kernel writes in contiguous
    # (linear) order — yielding permuted results. Force a contiguous buffer.
    out = torch.empty_like(x, memory_format=torch.contiguous_format)
    _launch_abs_real(x.contiguous(), out)
    return out
def _launch_into(launch, src_c, out):
    """Run *launch(src_c, dst)* so results land in *out*.

    The kernels write in contiguous (linear) order, so when *out* itself is
    not contiguous we stage through a contiguous temporary and copy back,
    preserving *out*'s layout.
    """
    if out.is_contiguous():
        launch(src_c, out)
    else:
        tmp = torch.empty_like(out, memory_format=torch.contiguous_format)
        launch(src_c, tmp)
        out.copy_(tmp)


def abs_out(x: torch.Tensor, out: torch.Tensor):
    """Elementwise absolute value of *x*, written into *out*; returns *out*.

    For complex *x*, *out* must have the matching real dtype
    (complex64 -> float32, etc.); otherwise *out* must share *x*'s dtype.
    *out* must match *x*'s shape; any layout of *out* is supported.

    Raises:
        TypeError: if an argument is not a tensor, or *out* has the wrong dtype.
        ValueError: if a tensor is not on CUDA, or shapes mismatch.
        NotImplementedError: for unsupported complex dtypes.
    """
    x = _ensure_cuda_tensor(x)
    out = _ensure_cuda_tensor(out)
    # Both branches shared identical validation and staging logic; select the
    # branch-specific pieces once, then run the common path.
    if x.is_complex():
        expected_dtype = _complex_abs_out_dtype(x.dtype)
        launch = _launch_abs_complex
    else:
        expected_dtype = x.dtype
        launch = _launch_abs_real
    if out.dtype != expected_dtype:
        raise TypeError(
            f"abs_out: expected out.dtype={expected_dtype}, got {out.dtype}"
        )
    if out.shape != x.shape:
        raise ValueError(f"abs_out: expected out.shape={x.shape}, got {out.shape}")
    _launch_into(launch, x.contiguous(), out)
    return out