Coverage for src/flag_gems/experimental_ops/sign.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def sign_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Elementwise sign over a flat buffer:
    #   out = 1 where x > 0, -1 where x < 0, otherwise x unchanged.
    # Each program handles one BLOCK_SIZE-wide chunk starting at
    # pid * BLOCK_SIZE; lanes at or beyond n_elements are masked off.
    #
    # Promote the program id to int64 before scaling it: program_id is int32,
    # and pid * BLOCK_SIZE would overflow for inputs with >= 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # Constants materialised in the input dtype so the result keeps x's dtype.
    one = tl.full([BLOCK_SIZE], 1, x.dtype)
    neg_one = tl.full([BLOCK_SIZE], -1, x.dtype)

    # Elements that are neither > 0 nor < 0 fall through to `x` itself:
    # zeros, and (for floating inputs) NaN, since comparisons with NaN are
    # false.
    res = tl.where(x > 0, one, tl.where(x < 0, neg_one, x))
    tl.store(out_ptr + offsets, res, mask=mask)

20 

21 

22def _launch_sign_kernel(x: torch.Tensor, out: torch.Tensor): 

23 n_elements = out.numel() 

24 if n_elements == 0: 

25 return 

26 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

27 sign_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024) 

28 

29 

def sign(x: torch.Tensor):
    """Return a new tensor holding the elementwise sign of ``x``.

    Positive elements map to 1 and negative elements to -1; all other
    elements (zeros and, for floats, NaN) are copied through unchanged.
    The result has ``x``'s shape, dtype and device, and is laid out
    contiguously (row-major).

    Raises:
        NotImplementedError: if ``x`` has a complex dtype.
    """
    if x.is_complex():
        raise NotImplementedError(
            "Complex dtypes are not supported by this Triton sign kernel."
        )
    x_contig = x.contiguous()
    # Allocate the output contiguously and hand it to the kernel as-is.
    # torch.empty_like(x) would copy x's strides. For a dense but
    # non-contiguous x (e.g. a transpose), out.contiguous() would then be a
    # temporary copy, and the returned tensor would never be written.
    out = torch.empty_like(x, memory_format=torch.contiguous_format)
    _launch_sign_kernel(x_contig, out)
    return out

38 

39 

def sign_out(x: torch.Tensor, out: torch.Tensor):
    """Write the elementwise sign of ``x`` into ``out`` and return ``out``.

    ``out`` may have any strides. If it is not contiguous, the result is
    computed into a contiguous scratch tensor and then copied into ``out``.

    Raises:
        NotImplementedError: if ``x`` or ``out`` has a complex dtype.
        ValueError: if ``out`` differs from ``x`` in shape, dtype or device.
    """
    if x.is_complex() or out.is_complex():
        raise NotImplementedError(
            "Complex dtypes are not supported by this Triton sign kernel."
        )
    if out.shape != x.shape:
        raise ValueError("Output tensor must have the same shape as input tensor.")
    if out.dtype != x.dtype:
        raise ValueError("Output tensor must have the same dtype as input tensor.")
    if out.device != x.device:
        raise ValueError("Output tensor must be on the same device as input tensor.")

    x_contig = x.contiguous()
    if out.is_contiguous():
        # Flat row-major layout matches x_contig: write straight into out.
        _launch_sign_kernel(x_contig, out)
    else:
        # Passing out.contiguous() would hand the kernel a temporary copy and
        # leave `out` untouched. Compute into scratch, then copy into out's
        # actual layout.
        scratch = torch.empty_like(x, memory_format=torch.contiguous_format)
        _launch_sign_kernel(x_contig, scratch)
        out.copy_(scratch)
    return out