Coverage for src/flag_gems/ops/sgn_.py: 56%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

# Module-scoped logger; the op wrapper emits a debug line on every call.
logger = logging.getLogger(name=__name__)

11 

12 

@triton.jit
def sgn_kernel_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite each element of a flat buffer with its sign, in place.

    Each program instance handles one tile of ``BLOCK_SIZE`` consecutive
    elements starting at ``program_id * BLOCK_SIZE``. Lanes at or beyond
    ``n_elements`` are masked off for both the load and the store.

    Per element the stored value is ``(x > 0) - (x < 0)`` cast to the input
    dtype. That is +1, -1 or 0, and for unsigned dtypes the negative term
    is always 0. Elements that compare unequal to themselves (NaN) are
    written back unchanged.

    Args:
        x_ptr: pointer to the buffer that is read and then overwritten.
        n_elements: number of valid elements in the buffer.
        BLOCK_SIZE: elements processed per program (compile-time constant).
    """
    # Promote the program id to int64 before scaling. With the default int32
    # pid, ``pid * BLOCK_SIZE`` wraps once the buffer holds more than
    # 2**31 - 1 elements. The resulting negative offsets still satisfy
    # ``offsets < n_elements`` and would be loaded and stored out of bounds.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)

    # Boolean comparisons cast to the input dtype give +1 / 0 / -1.
    res = (x > 0).to(x.dtype) - (x < 0).to(x.dtype)

    # NaN is the only value unequal to itself; keep NaN inputs as NaN.
    # For integer dtypes this predicate is always False.
    res = tl.where(x != x, x, res)

    tl.store(x_ptr + offsets, res, mask=mask)

31 

32 

def sgn_(*args, **kwargs):
    """Replace every element of a tensor with its sign, in place.

    The tensor may be given as the only positional argument, or as the
    ``input`` or ``self`` keyword, checked in that order.

    Contiguous, non-complex tensors of the dtypes in ``kernel_dtypes`` are
    processed by ``sgn_kernel_``. Every other tensor is handed to
    ``torch.ops.aten.sgn_`` instead.

    Returns:
        The input tensor, modified in place. On the fallback path, whatever
        ``torch.ops.aten.sgn_`` returns.

    Raises:
        TypeError: if no tensor argument can be located.
    """
    logger.debug("GEMS SGN_")

    # Locate the target tensor: a lone positional tensor wins; otherwise
    # take the first of the 'input' / 'self' keywords that holds a tensor.
    if len(args) == 1 and isinstance(args[0], torch.Tensor):
        tensor = args[0]
    else:
        tensor = next(
            (
                kwargs[key]
                for key in ("input", "self")
                if key in kwargs and isinstance(kwargs[key], torch.Tensor)
            ),
            None,
        )

    if tensor is None:
        raise TypeError("sgn_ expects a single Tensor argument")

    kernel_dtypes = (
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    )
    # The kernel indexes a flat buffer and only handles real values of the
    # dtypes above; everything else goes to the aten implementation.
    # NOTE(review): if this function is itself registered as the aten::sgn_
    # implementation for this device, confirm this call does not dispatch
    # back here.
    if (
        not tensor.is_contiguous()
        or tensor.is_complex()
        or tensor.dtype not in kernel_dtypes
    ):
        return torch.ops.aten.sgn_(tensor)

    total = tensor.numel()
    if total == 0:
        return tensor

    def launch_grid(meta):
        # One program per BLOCK_SIZE-sized tile, rounding up.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(tensor.device):
        sgn_kernel_[launch_grid](tensor, total, BLOCK_SIZE=1024)
    return tensor