Coverage for src/flag_gems/experimental_ops/sgn_.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def sgn_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place elementwise sign: each element becomes -1, 0, or +1 of its
    own dtype; NaN elements are left as NaN for floating dtypes.

    Args:
        x_ptr: pointer to the tensor data (read and overwritten in place).
        n_elements: total number of elements to process.
        BLOCK_SIZE: compile-time block width; one program handles one block.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds, other=0)

    # sign = (x > 0) - (x < 0), computed in the input dtype.
    sign = (val > 0).to(val.dtype) - (val < 0).to(val.dtype)

    # NaN != NaN, so this keeps NaNs intact for floats; for integer
    # dtypes the comparison is always false and `sign` passes through.
    out = tl.where(val != val, val, sign)

    tl.store(x_ptr + idx, out, mask=in_bounds)

24 

25 

# Keep a handle to the @triton.jit kernel before the Python wrapper defined
# below rebinds the module-level name `sgn_`.
sgn___kernel = sgn_

27 

28 

29def sgn_(*args, **kwargs): 

30 # Expect a single tensor argument (in-place op) 

31 x = None 

32 if len(args) == 1 and isinstance(args[0], torch.Tensor): 

33 x = args[0] 

34 elif "input" in kwargs and isinstance(kwargs["input"], torch.Tensor): 

35 x = kwargs["input"] 

36 elif "self" in kwargs and isinstance(kwargs["self"], torch.Tensor): 

37 x = kwargs["self"] 

38 

39 if x is None: 

40 raise TypeError("sgn_ expects a single Tensor argument") 

41 

42 # Fallback for unsupported cases 

43 unsupported = (not x.is_cuda) or (not x.is_contiguous()) or x.is_complex() 

44 supported_dtypes = { 

45 torch.float16, 

46 torch.float32, 

47 torch.float64, 

48 torch.bfloat16, 

49 torch.int8, 

50 torch.int16, 

51 torch.int32, 

52 torch.int64, 

53 torch.uint8, 

54 } 

55 if unsupported or x.dtype not in supported_dtypes: 

56 return torch.ops.aten.sgn_(x) 

57 

58 n_elements = x.numel() 

59 if n_elements == 0: 

60 return x 

61 

62 grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) 

63 sgn___kernel[grid](x, n_elements, BLOCK_SIZE=1024) 

64 return x