Coverage for src/flag_gems/experimental_ops/sign.py: 0%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def sign_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise sign kernel.

    For each in-bounds element: write 1 where x > 0, -1 where x < 0, and x
    itself otherwise (so 0.0, -0.0 and NaN pass through unchanged).
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)

    # Constants are built in the input dtype so no implicit upcast happens.
    pos = tl.full([BLOCK_SIZE], 1, vals.dtype)
    neg = tl.full([BLOCK_SIZE], -1, vals.dtype)

    # Fall through to `vals` when neither comparison holds (zero / NaN).
    result = tl.where(vals > 0, pos, tl.where(vals < 0, neg, vals))
    tl.store(out_ptr + idx, result, mask=in_bounds)
def _launch_sign_kernel(x: torch.Tensor, out: torch.Tensor):
    """Launch ``sign_kernel`` over every element of ``out``.

    A zero-element tensor is a no-op (an empty grid launch is skipped).
    Both tensors are assumed to be laid out so that flat offsets are valid
    for each — callers pass contiguous tensors.
    """
    total = out.numel()
    if total == 0:
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    sign_kernel[grid](x, out, total, BLOCK_SIZE=1024)
def sign(x: torch.Tensor):
    """Return a new tensor holding the elementwise sign of ``x``.

    Positive elements map to 1, negative to -1; zeros (and NaNs) pass
    through unchanged, as implemented by ``sign_kernel``.

    Raises:
        NotImplementedError: if ``x`` has a complex dtype.
    """
    if x.is_complex():
        raise NotImplementedError(
            "Complex dtypes are not supported by this Triton sign kernel."
        )
    # Make the input contiguous once and allocate the output from the
    # contiguous input so it is contiguous too. The previous code called
    # out.contiguous() at launch time; for a non-contiguous `out` that
    # returns a *copy*, so the kernel filled a temporary and the returned
    # tensor stayed uninitialized.
    x_contig = x.contiguous()
    out = torch.empty_like(x_contig)
    _launch_sign_kernel(x_contig, out)
    return out
def sign_out(x: torch.Tensor, out: torch.Tensor):
    """Compute the elementwise sign of ``x`` into ``out`` and return ``out``.

    ``out`` must match ``x`` in shape, dtype and device.

    Raises:
        NotImplementedError: if either tensor has a complex dtype.
        ValueError: if ``out`` mismatches ``x`` in shape, dtype or device.
    """
    if x.is_complex() or out.is_complex():
        raise NotImplementedError(
            "Complex dtypes are not supported by this Triton sign kernel."
        )
    if out.shape != x.shape:
        raise ValueError("Output tensor must have the same shape as input tensor.")
    if out.dtype != x.dtype:
        raise ValueError("Output tensor must have the same dtype as input tensor.")
    if out.device != x.device:
        raise ValueError("Output tensor must be on the same device as input tensor.")

    x_contig = x.contiguous()
    if out.is_contiguous():
        _launch_sign_kernel(x_contig, out)
    else:
        # .contiguous() on a non-contiguous tensor returns a copy, so the
        # previous launch wrote into a temporary and `out` was never filled.
        # Compute into a contiguous buffer, then copy back into `out`.
        tmp = torch.empty_like(x_contig)
        _launch_sign_kernel(x_contig, tmp)
        out.copy_(tmp)
    return out