Coverage for src/flag_gems/experimental_ops/sgn.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def sgn_real_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise sign for real (float or integer) flattened tensors.

    Writes out[i] = 1 if x[i] > 0, -1 if x[i] < 0, else 0.  NaN inputs
    propagate as NaN, matching torch.sgn semantics for floating point.

    Args:
        x_ptr: pointer to the flattened input values.
        out_ptr: pointer to the flattened output buffer (same dtype).
        n_elements: total number of elements to process.
        BLOCK_SIZE: elements handled per program instance.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # BUG FIX: the previous typed-constant trick (zero = x - x; one = zero + 1)
    # produced NaN whenever x was +/-inf, because inf - inf == NaN.  That made
    # sgn(+inf) and sgn(-inf) come out NaN instead of +/-1.  Building the
    # result directly from the comparison masks is correct for every finite
    # and infinite value, and works for integer dtypes as well.
    out = (x > 0).to(x.dtype) - (x < 0).to(x.dtype)
    # Preserve NaN (x != x is true only for NaN), matching torch.sgn.
    out = tl.where(x != x, x, out)

    tl.store(out_ptr + offsets, out, mask=mask)

22 

23 

@triton.jit
def sgn_complex_kernel(x_ri_ptr, out_ri_ptr, n_complex, BLOCK_SIZE: tl.constexpr):
    """Elementwise sgn for complex tensors stored as interleaved real/imag.

    Both buffers hold the real-imag flattened layout: for complex element k,
    the real part sits at index 2*k and the imaginary part at 2*k + 1.
    Writes z / |z| for each nonzero element and 0+0j where |z| == 0.

    Args:
        x_ri_ptr: pointer to the interleaved input components.
        out_ri_ptr: pointer to the interleaved output components.
        n_complex: number of complex elements (not component count).
        BLOCK_SIZE: complex elements handled per program instance.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_complex

    # Component addresses: real at base, imag at base + 1.
    base = idx * 2
    re = tl.load(x_ri_ptr + base, mask=valid)
    im = tl.load(x_ri_ptr + base + 1, mask=valid)

    # Typed constants derived from the loaded values (avoids tl.full).
    zero = re - re
    one = zero + 1

    # |z| and its reciprocal; the reciprocal is inf at |z| == 0, but that
    # lane is masked out by the tl.where selections below.
    magnitude = tl.sqrt(re * re + im * im)
    recip = one / magnitude
    nonzero = magnitude != 0

    tl.store(out_ri_ptr + base, tl.where(nonzero, re * recip, zero), mask=valid)
    tl.store(out_ri_ptr + base + 1, tl.where(nonzero, im * recip, zero), mask=valid)

51 

52 

def _sgn_impl(input: torch.Tensor) -> torch.Tensor:
    """Compute torch.sgn(input) on CUDA via Triton kernels.

    Real/integer tensors get the elementwise sign; complex tensors get
    z / |z| (0 where z == 0), computed on an interleaved real/imag view.

    Args:
        input: CUDA tensor of any shape; real, integer, or complex dtype.

    Returns:
        A new contiguous tensor of the same shape and dtype as ``input``.

    Raises:
        AssertionError: if ``input`` is not a CUDA torch.Tensor.
    """
    # NOTE(review): asserts are stripped under ``python -O``; kept as-is so
    # callers that catch AssertionError keep working.
    assert isinstance(input, torch.Tensor), "input must be a torch.Tensor"
    assert input.is_cuda, "input must be on CUDA device"

    # Compute into a contiguous result buffer.
    result = torch.empty_like(input, memory_format=torch.contiguous_format)

    # ROBUSTNESS: nothing to launch for empty tensors; skip the zero-size
    # grid entirely instead of relying on the launcher to tolerate it.
    if input.numel() == 0:
        return result

    BLOCK_SIZE = 1024
    if input.is_complex():
        # Flatten to interleaved [re0, im0, re1, im1, ...] component arrays.
        in_ri = torch.view_as_real(input.contiguous()).view(-1)
        # BUG FIX: the output view must alias ``result`` so kernel writes
        # land in it.  The previous ``view_as_real(result).contiguous()``
        # only worked because .contiguous() was a no-op on the already
        # contiguous view; had it copied, the stores would have been lost.
        # ``result`` is contiguous by construction, so this is a true view.
        out_ri = torch.view_as_real(result).view(-1)
        n_complex = input.numel()
        grid = lambda meta: (triton.cdiv(n_complex, meta["BLOCK_SIZE"]),)
        sgn_complex_kernel[grid](
            in_ri,
            out_ri,
            n_complex,
            BLOCK_SIZE=BLOCK_SIZE,
        )
    else:
        x = input.contiguous().view(-1)
        out_flat = result.view(-1)
        n_elements = x.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sgn_real_kernel[grid](
            x,
            out_flat,
            n_elements,
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return result

84 

85 

def sgn(input: torch.Tensor, *, out: torch.Tensor = None):
    """
    Wrapper for ATen operator: ('sgn', <Autograd.disable: False>)

    Computes the elementwise sign of ``input``; when ``out`` is given, the
    result is copied into it and ``out`` is returned.
    """
    result = _sgn_impl(input)
    if out is None:
        return result
    out.copy_(result)
    return out

95 

96 

def sgn_out(input: torch.Tensor, out: torch.Tensor):
    """
    Wrapper for ATen operator: ('sgn.out', <Autograd.disable: False>)

    Computes the elementwise sign of ``input`` directly into ``out``.
    """
    out.copy_(_sgn_impl(input))
    return out