Coverage for src/flag_gems/experimental_ops/absolute.py: 0%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _absolute_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise ``|x|`` over a flat buffer of ``n_elements`` values.

    Each program handles one ``BLOCK_SIZE``-wide chunk starting at
    ``program_id * BLOCK_SIZE``. Lanes at or past ``n_elements`` are masked
    off for both the load and the store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # tl.abs handles both integer and floating-point dtypes. A compare-and-negate
    # against ``x * 0`` is wrong for floats: ``-inf * 0`` is NaN, so -inf came
    # back unchanged, and ``-0.0 < 0`` is False, so -0.0 was not mapped to +0.0.
    y = tl.abs(x)
    tl.store(out_ptr + offsets, y, mask=mask)

18 

19 

@triton.jit
def _absolute_complex_kernel(ri_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Magnitude ``|re + i*im|`` of ``n_elements`` complex values.

    ``ri_ptr`` points at interleaved (real, imag) scalars: element ``k`` has its
    real part at index ``2*k`` and its imaginary part at ``2*k + 1``. The callers
    in this file pass ``torch.view_as_real(x).contiguous()``. One real magnitude
    per element is written to ``out_ptr[k]``.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # The pair buffer holds 2 * n_elements scalars. Double the index in int64 so
    # it cannot wrap once n_elements exceeds 2**30.
    base = offsets.to(tl.int64) * 2
    re = tl.load(ri_ptr + base, mask=mask)
    im = tl.load(ri_ptr + base + 1, mask=mask)

    # hypot-style scaling: |z| = hi * sqrt(1 + (lo / hi)^2) with hi >= lo.
    # Squaring re/im directly overflows to inf (or underflows to 0) for
    # magnitudes well inside the dtype's range, e.g. |re| > ~1.8e19 in float32.
    a = tl.abs(re)
    b = tl.abs(im)
    # Comparison is False when either side is NaN; the NaN then lands in hi or
    # lo and propagates to the result.
    swap = b > a
    hi = tl.where(swap, b, a)
    lo = tl.where(swap, a, b)
    # hi == 0 implies lo == 0 (or lo is NaN); use lo itself as the ratio there
    # instead of evaluating 0 / 0.
    ratio = tl.where(hi == 0, lo, lo / hi)
    y = hi * tl.sqrt(1 + ratio * ratio)
    # An infinite component gives +inf even if the other one is NaN (C99 hypot
    # convention). The scaled form alone would yield inf / inf = NaN here.
    is_inf = (a == float("inf")) | (b == float("inf"))
    y = tl.where(is_inf, float("inf"), y)
    tl.store(out_ptr + offsets, y, mask=mask)

32 

33 

def absolute(input: torch.Tensor):
    """Return the element-wise absolute value of ``input`` as a new tensor.

    Complex inputs produce their magnitude with dtype ``input.real.dtype``.
    All other inputs produce a tensor of the same dtype. The input is made
    contiguous first, so the result is always a contiguous tensor with the
    input's shape.
    """
    src = input.contiguous()
    numel = src.numel()

    def grid(meta):
        # One program per BLOCK_SIZE-element chunk of the flattened tensor.
        return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    if not src.is_complex():
        result = torch.empty_like(src)
        _absolute_kernel[grid](src, result, numel, BLOCK_SIZE=1024)
        return result

    # Interleaved (real, imag) scalars; the kernel reads element k at 2k and 2k + 1.
    pairs = torch.view_as_real(src).contiguous()
    result = torch.empty(src.shape, dtype=src.real.dtype, device=src.device)
    _absolute_complex_kernel[grid](pairs, result, numel, BLOCK_SIZE=1024)
    return result

49 

50 

def absolute_out(input: torch.Tensor, out: torch.Tensor):
    """Write the element-wise absolute value of ``input`` into ``out``.

    Args:
        input: Tensor of any dtype. It is made contiguous before the kernel
            reads it.
        out: Pre-allocated destination. It must:
            - have ``input``'s shape;
            - be contiguous;
            - live on ``input``'s device;
            - have dtype ``input.real.dtype`` for complex inputs, or
              ``input.dtype`` otherwise.

    Returns:
        ``out``, filled in place.

    Raises:
        AssertionError: if ``out`` violates any requirement above. The error
            is raised explicitly rather than via ``assert`` so the checks still
            run under ``python -O``. Without them the kernel could write past
            the end of an undersized ``out``.
    """
    x = input.contiguous()
    n_elements = x.numel()
    is_complex = x.is_complex()

    if is_complex:
        expected_dtype = x.real.dtype
        dtype_msg = "out dtype must be the real dtype of the complex input"
    else:
        expected_dtype = x.dtype
        dtype_msg = "out dtype must match input dtype"

    if out.dtype != expected_dtype:
        raise AssertionError(dtype_msg)
    if out.shape != x.shape:
        raise AssertionError("out must have the same shape as input")
    if not out.is_contiguous():
        raise AssertionError("out must be contiguous")
    if out.device != x.device:
        raise AssertionError("out must be on the same device as input")

    def grid(meta):
        # One program per BLOCK_SIZE-element chunk of the flattened tensor.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    if is_complex:
        # Interleaved (real, imag) scalars; the kernel reads element k at 2k and 2k + 1.
        ri = torch.view_as_real(x).contiguous()
        _absolute_complex_kernel[grid](ri, out, n_elements, BLOCK_SIZE=1024)
    else:
        _absolute_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out