Coverage for src/flag_gems/experimental_ops/sgn.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def sgn_real_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Elementwise sign for real-valued data:
    #   out[i] = 1 if x[i] > 0, -1 if x[i] < 0, else 0.
    #
    # x_ptr / out_ptr: pointers to flat buffers of n_elements values.
    # BLOCK_SIZE: compile-time tile width; one program handles one tile.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-range lanes in the final partial tile.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Create typed constants without relying on tl.full: `x - x` produces a
    # zero block carrying x's dtype, so 0 / 1 / -1 all inherit the input's
    # element type instead of defaulting to some other type.
    zero = x - x
    one = zero + 1
    neg_one = -one
    out = tl.where(x > 0, one, tl.where(x < 0, neg_one, zero))
    tl.store(out_ptr + offsets, out, mask=mask)
@triton.jit
def sgn_complex_kernel(x_ri_ptr, out_ri_ptr, n_complex, BLOCK_SIZE: tl.constexpr):
    # Complex sign: out = x / |x| for x != 0, and 0 for x == 0,
    # operating on interleaved real/imag float buffers.
    #
    # x_ri_ptr and out_ri_ptr are pointers to the real-imag flattened arrays:
    # for element k: real at 2*k, imag at 2*k + 1
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # `offsets` indexes complex elements, so mask against n_complex, not the
    # (2x longer) raw buffer length.
    mask = offsets < n_complex
    base = offsets * 2
    r = tl.load(x_ri_ptr + base, mask=mask)
    i = tl.load(x_ri_ptr + base + 1, mask=mask)
    # typed constants (see sgn_real_kernel: `r - r` is a zero of r's dtype)
    zero = r - r
    one = zero + 1
    # Compute norm and its reciprocal
    norm = tl.sqrt(r * r + i * i)
    inv = one / norm  # Will be inf when norm == 0; handled by where below
    # Where norm == 0, r * inv / i * inv are inf/nan, but tl.where selects
    # `zero` for those lanes, so the garbage never reaches the output.
    nz = norm != 0
    out_r = tl.where(nz, r * inv, zero)
    out_i = tl.where(nz, i * inv, zero)
    tl.store(out_ri_ptr + base, out_r, mask=mask)
    tl.store(out_ri_ptr + base + 1, out_i, mask=mask)
53def _sgn_impl(input: torch.Tensor) -> torch.Tensor:
54 assert isinstance(input, torch.Tensor), "input must be a torch.Tensor"
55 assert input.is_cuda, "input must be on CUDA device"
56 # Compute into a contiguous result buffer
57 result = torch.empty_like(input, memory_format=torch.contiguous_format)
59 BLOCK_SIZE = 1024
60 if input.is_complex():
61 # Use real-imag views for complex types
62 in_ri = torch.view_as_real(input).contiguous().view(-1)
63 out_ri = torch.view_as_real(result).contiguous().view(-1)
64 n_complex = input.numel()
65 grid = lambda meta: (triton.cdiv(n_complex, meta["BLOCK_SIZE"]),)
66 sgn_complex_kernel[grid](
67 in_ri,
68 out_ri,
69 n_complex,
70 BLOCK_SIZE=BLOCK_SIZE,
71 )
72 else:
73 x = input.contiguous().view(-1)
74 out_flat = result.view(-1)
75 n_elements = x.numel()
76 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
77 sgn_real_kernel[grid](
78 x,
79 out_flat,
80 n_elements,
81 BLOCK_SIZE=BLOCK_SIZE,
82 )
83 return result
def sgn(input: torch.Tensor, *, out: torch.Tensor = None):
    """
    Wrapper for ATen operator: ('sgn', <Autograd.disable: False>)

    Computes the sign of ``input``; when ``out`` is given, the result is
    copied into it and ``out`` is returned.
    """
    result = _sgn_impl(input)
    if out is None:
        return result
    out.copy_(result)
    return out
def sgn_out(input: torch.Tensor, out: torch.Tensor):
    """
    Wrapper for ATen operator: ('sgn.out', <Autograd.disable: False>)

    Computes the sign of ``input`` into ``out`` and returns ``out``.
    """
    out.copy_(_sgn_impl(input))
    return out