Coverage for src/flag_gems/experimental_ops/sgn_.py: 0%
37 statements
coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def sgn_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite each element of a flat buffer with its sign, in place.

    Each program handles one BLOCK_SIZE-wide slice of the buffer. Positive
    values become 1, negative values -1 and zeros 0 (all in x's dtype); NaNs
    are written back unchanged so they propagate.

    Args:
        x_ptr: pointer to the first element; read and overwritten in place.
        n_elements: total number of elements; lanes at or past it are masked.
        BLOCK_SIZE: compile-time number of elements handled per program.
    """
    # Work in int64: with the int32 program id, pid * BLOCK_SIZE wraps once the
    # buffer exceeds 2**31 - 1 elements, yielding negative offsets that still
    # pass the `< n_elements` mask and read/write out of bounds.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)

    # (x > 0) - (x < 0) gives 1 / 0 / -1; for unsigned dtypes `neg` is never set.
    pos = x > 0
    neg = x < 0
    res = pos.to(x.dtype) - neg.to(x.dtype)

    # Propagate NaNs for floating types. For integer types, (x != x) is always false.
    is_nan = x != x
    res = tl.where(is_nan, x, res)

    tl.store(x_ptr + offsets, res, mask=mask)


# The Python wrapper defined later in this module reuses the name ``sgn_`` for
# the public entry point, which shadows this kernel; keep a separate handle so
# the wrapper can still launch it.
sgn___kernel = sgn_
# Real dtypes the Triton kernel handles; anything else goes to the ATen fallback.
_SGN_SUPPORTED_DTYPES = frozenset(
    {
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    }
)


def sgn_(*args, **kwargs):
    """In-place ``sgn_`` backed by a Triton kernel where possible.

    The tensor is taken from the single positional argument, or else from the
    ``input`` or ``self`` keyword.

    Contiguous CUDA tensors whose dtype is in ``_SGN_SUPPORTED_DTYPES`` are
    processed by ``sgn___kernel``. Every other case (non-CUDA, non-contiguous,
    complex, or another dtype such as bool) is delegated to
    ``torch.ops.aten.sgn_``.

    Returns:
        The input tensor itself, modified in place.

    Raises:
        TypeError: if no tensor argument could be found.
    """
    # Expect a single tensor argument (in-place op)
    x = None
    if len(args) == 1 and isinstance(args[0], torch.Tensor):
        x = args[0]
    elif "input" in kwargs and isinstance(kwargs["input"], torch.Tensor):
        x = kwargs["input"]
    elif "self" in kwargs and isinstance(kwargs["self"], torch.Tensor):
        x = kwargs["self"]

    if x is None:
        raise TypeError("sgn_ expects a single Tensor argument")

    # Fallback for unsupported cases
    unsupported = (not x.is_cuda) or (not x.is_contiguous()) or x.is_complex()
    if unsupported or x.dtype not in _SGN_SUPPORTED_DTYPES:
        return torch.ops.aten.sgn_(x)

    n_elements = x.numel()
    if n_elements == 0:
        return x

    grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
    # Triton launches on the *current* CUDA device rather than the tensor's
    # device; without this guard a tensor on e.g. cuda:1 would be handed to a
    # kernel running on cuda:0.
    with torch.cuda.device(x.device):
        sgn___kernel[grid](x, n_elements, BLOCK_SIZE=1024)
    return x