Coverage for src/flag_gems/ops/absolute.py: 47%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
@triton.jit
def _absolute_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Write the element-wise absolute value of a flat buffer to ``out_ptr``.

    Each program instance handles one ``BLOCK_SIZE``-wide slice of the
    ``n_elements`` values starting at ``x_ptr``.

    Args:
        x_ptr: Pointer to the first input element.
        out_ptr: Pointer to the first output element.
        n_elements: Number of valid elements; lanes at or past it are masked.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block so out-of-range lanes neither load nor store.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Use the built-in rather than comparing against ``x * 0``: that product
    # is NaN for infinite inputs, so the comparison was always false and -inf
    # came back unchanged; it also left the sign bit set on -0.0.
    y = tl.abs(x)
    tl.store(out_ptr + offsets, y, mask=mask)
@triton.jit
def _absolute_complex_kernel(ri_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Write the magnitude of each interleaved (real, imag) pair to ``out_ptr``.

    Element ``i`` of the output is computed from the values at ``ri_ptr + 2*i``
    (real part) and ``ri_ptr + 2*i + 1`` (imaginary part).

    Args:
        ri_ptr: Pointer to the interleaved real/imaginary input values.
        out_ptr: Pointer to the first output element.
        n_elements: Number of complex elements; lanes at or past it are masked.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block so out-of-range lanes neither load nor store.
    mask = offsets < n_elements
    # Widen before doubling: a 32-bit ``offsets * 2`` wraps once the element
    # index exceeds 2**30.
    base = offsets.to(tl.int64) * 2
    re = tl.load(ri_ptr + base, mask=mask)
    im = tl.load(ri_ptr + base + 1, mask=mask)

    # Scaled hypot: sqrt(re*re + im*im) overflows to inf (or underflows to 0)
    # when the squares leave the floating-point range even though the
    # magnitude itself is representable. Factoring out the larger component
    # keeps every intermediate within range.
    abs_re = tl.abs(re)
    abs_im = tl.abs(im)
    swap = abs_re < abs_im
    hi = tl.where(swap, abs_im, abs_re)
    lo = tl.where(swap, abs_re, abs_im)
    ratio = lo / hi
    y = hi * tl.sqrt(1.0 + ratio * ratio)
    # Both components zero: 0 / 0 above produced NaN, the magnitude is 0.
    y = tl.where(hi == 0.0, hi, y)
    # An infinite component gives an infinite magnitude, even when the ratio
    # above was inf / inf or the other component is NaN.
    y = tl.where(abs_re == float("inf"), abs_re, y)
    y = tl.where(abs_im == float("inf"), abs_im, y)
    tl.store(out_ptr + offsets, y, mask=mask)
def absolute(input: torch.Tensor):
    """Return the element-wise absolute value of ``input`` via Triton kernels.

    Complex inputs yield a tensor of the corresponding real dtype holding the
    magnitude of each element; any other input yields a tensor allocated with
    ``torch.empty_like`` on the contiguous copy of the input.

    Args:
        input: Tensor whose absolute value is computed.

    Returns:
        A newly allocated tensor with the same shape as ``input``.
    """
    logger.debug("GEMS ABSOLUTE")
    src = input.contiguous()
    total = src.numel()

    def launch_grid(meta):
        # One program per BLOCK_SIZE-sized slice of the flattened data.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(input.device):
        if not src.is_complex():
            result = torch.empty_like(src)
            _absolute_kernel[launch_grid](src, result, total, BLOCK_SIZE=1024)
            return result

        # Expose the complex data as real values so the kernel can read the
        # real and imaginary parts of each element separately.
        pairs = torch.view_as_real(src).contiguous()
        real_dtype = src.real.dtype
        result = torch.empty(src.shape, dtype=real_dtype, device=src.device)
        _absolute_complex_kernel[launch_grid](pairs, result, total, BLOCK_SIZE=1024)
        return result