Coverage for src/flag_gems/experimental_ops/abs.py: 0%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _abs_kernel_real(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise absolute value for real (integer or floating) data.

    Each program instance handles BLOCK_SIZE consecutive elements of the
    flat, contiguous input/output buffers; a mask guards the final
    partial block.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    # BUG FIX: the previous tl.where(x >= 0, x, -x) returned -0.0 for a
    # -0.0 input (because -0.0 >= 0 is True), whereas torch.abs yields
    # +0.0. tl.abs computes |x| for both integer and floating types and
    # normalizes the sign of zero.
    y = tl.abs(x)
    tl.store(out_ptr + offsets, y, mask=mask)

16 

17 

@triton.jit
def _abs_kernel_complex(rr_ptr, out_ptr, n_complex, BLOCK_SIZE: tl.constexpr):
    """Magnitude of interleaved complex values: out[i] = sqrt(re_i^2 + im_i^2).

    rr_ptr addresses the raw scalar storage [re0, im0, re1, im1, ...];
    out_ptr receives one real magnitude per complex element.
    NOTE(review): re*re + im*im can overflow/underflow for extreme
    magnitudes (a hypot-style formulation would be safer) — confirm this
    precision trade-off is acceptable.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_complex
    # Each complex element occupies two consecutive scalars in storage.
    pair_off = idx * 2
    real = tl.load(rr_ptr + pair_off, mask=in_bounds)
    imag = tl.load(rr_ptr + pair_off + 1, mask=in_bounds)
    tl.store(out_ptr + idx, tl.sqrt(real * real + imag * imag), mask=in_bounds)

30 

31 

32def _ensure_cuda_tensor(x: torch.Tensor): 

33 if not isinstance(x, torch.Tensor): 

34 raise TypeError("Input must be a torch.Tensor") 

35 if x.device.type != "cuda": 

36 raise ValueError("Tensor must be on CUDA device") 

37 return x 

38 

39 

40def _complex_abs_out_dtype(dtype: torch.dtype) -> torch.dtype: 

41 if dtype == torch.complex64: 

42 return torch.float32 

43 if dtype == torch.complex128: 

44 return torch.float64 

45 # Optional support if complex32 exists 

46 if hasattr(torch, "complex32") and dtype == getattr(torch, "complex32"): 

47 return torch.float16 

48 raise NotImplementedError(f"Unsupported complex dtype for abs: {dtype}") 

49 

50 

def _launch_abs_real(inp: torch.Tensor, out: torch.Tensor):
    """Launch the real-valued abs kernel over *inp*, writing into *out*.

    Both tensors are expected to be contiguous and hold out.numel()
    elements; an empty tensor is a no-op.
    """
    total = out.numel()
    if total == 0:
        return  # avoid launching a zero-sized grid
    block = 1024

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _abs_kernel_real[grid](inp, out, total, BLOCK_SIZE=block)

58 

59 

def _launch_abs_complex(inp: torch.Tensor, out: torch.Tensor):
    """Launch the complex-magnitude kernel.

    Args:
        inp: contiguous complex CUDA tensor.
        out: contiguous real CUDA tensor of matching shape whose dtype is
            the real component dtype of ``inp``.

    Raises:
        NotImplementedError: for unsupported complex dtypes.
    """
    n_complex = inp.numel()
    if n_complex == 0:
        return
    # Reinterpret the complex storage as interleaved real scalars
    # [re0, im0, re1, im1, ...]. CONSISTENCY: _complex_abs_out_dtype
    # already defines the complex->real mapping (and raises the same
    # NotImplementedError for unsupported dtypes), so the dtype dispatch
    # previously duplicated here is delegated to it.
    rr = inp.view(_complex_abs_out_dtype(inp.dtype))
    BLOCK = 1024
    grid = lambda meta: (triton.cdiv(n_complex, meta["BLOCK_SIZE"]),)
    _abs_kernel_complex[grid](rr, out, n_complex, BLOCK_SIZE=BLOCK)

77 

78 

def abs(x: torch.Tensor):
    """Elementwise absolute value of a CUDA tensor (torch.abs analogue).

    For complex input the result is the real-valued magnitude with the
    corresponding real dtype; otherwise the result keeps ``x``'s dtype.
    The returned tensor is freshly allocated and row-major contiguous.

    Args:
        x: CUDA tensor (integer, floating, or complex).

    Returns:
        New contiguous tensor holding ``|x|``.

    Raises:
        TypeError: if *x* is not a tensor.
        ValueError: if *x* is not on a CUDA device.
        NotImplementedError: for unsupported complex dtypes.
    """
    x = _ensure_cuda_tensor(x)
    if x.is_complex():
        out_dtype = _complex_abs_out_dtype(x.dtype)
        # torch.empty with no memory_format argument is row-major contiguous.
        out = torch.empty(x.shape, dtype=out_dtype, device=x.device)
        _launch_abs_complex(x.contiguous(), out)
        return out
    # BUG FIX: empty_like defaults to preserve_format, so for a
    # non-contiguous x (e.g. channels-last) the output kept x's strides
    # while the kernel writes linearly against the row-major contiguous
    # copy of x — scrambling the logical result. Force a row-major
    # contiguous output so storage order matches the kernel's writes.
    out = torch.empty_like(x, memory_format=torch.contiguous_format)
    _launch_abs_real(x.contiguous(), out)
    return out

94 

95 

def abs_out(x: torch.Tensor, out: torch.Tensor):
    """Compute ``|x|`` into a caller-provided tensor (torch.abs ``out=`` form).

    *out* must match ``x.shape`` and have the appropriate dtype: ``x``'s
    own dtype for real input, or the real component dtype for complex
    input. A non-contiguous *out* is filled via a contiguous scratch
    buffer followed by ``copy_``.

    Returns *out*.

    Raises:
        TypeError: for non-tensor arguments or a dtype mismatch.
        ValueError: for a device or shape mismatch.
        NotImplementedError: for unsupported complex dtypes.
    """
    x = _ensure_cuda_tensor(x)
    out = _ensure_cuda_tensor(out)

    # Both branches differ only in the expected output dtype and the
    # kernel launcher; resolve those first, then run one shared path.
    if x.is_complex():
        expected_dtype = _complex_abs_out_dtype(x.dtype)
        launch = _launch_abs_complex
    else:
        expected_dtype = x.dtype
        launch = _launch_abs_real

    if out.dtype != expected_dtype:
        raise TypeError(
            f"abs_out: expected out.dtype={expected_dtype}, got {out.dtype}"
        )
    if out.shape != x.shape:
        raise ValueError(f"abs_out: expected out.shape={x.shape}, got {out.shape}")

    src = x.contiguous()
    if out.is_contiguous():
        launch(src, out)
    else:
        # The kernels write linearly, so stage into a contiguous scratch
        # tensor and let copy_ handle out's stride layout.
        scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
        launch(src, scratch)
        out.copy_(scratch)
    return out