Coverage for src/flag_gems/ops/absolute.py: 47%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def _absolute_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise absolute value for real dtypes.

    Each program instance handles one contiguous BLOCK_SIZE-wide slice of
    the flattened input; lanes past n_elements are masked out.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Use tl.abs rather than tl.where(x < 0, -x, x): the where-form leaves
    # -0.0 unchanged (-0.0 < 0 is false), while torch.abs(-0.0) == +0.0.
    y = tl.abs(x)
    tl.store(out_ptr + offsets, y, mask=mask)

24 

25 

@triton.jit
def _absolute_complex_kernel(ri_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise magnitude of complex values stored as interleaved (re, im).

    ri_ptr points at the real/imag-interleaved view (view_as_real of a
    contiguous complex tensor), so element i lives at ri_ptr[2*i] (real)
    and ri_ptr[2*i + 1] (imag). out_ptr receives the real-dtype magnitudes.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    base = offsets * 2
    re = tl.load(ri_ptr + base, mask=mask)
    im = tl.load(ri_ptr + base + 1, mask=mask)
    # hypot-style scaling: naive sqrt(re*re + im*im) overflows once either
    # component exceeds ~sqrt(dtype max) (e.g. ~1.8e19 in float32) and loses
    # precision for tiny components. Scale by the larger magnitude first.
    abs_re = tl.abs(re)
    abs_im = tl.abs(im)
    big = tl.maximum(abs_re, abs_im)
    # Avoid 0/0 when both components are zero; result is big * sqrt(...) = 0.
    denom = tl.where(big == 0, 1.0, big)
    p = abs_re / denom
    q = abs_im / denom
    y = big * tl.sqrt(p * p + q * q)
    tl.store(out_ptr + offsets, y, mask=mask)

37 

38 

def absolute(input: torch.Tensor) -> torch.Tensor:
    """Return the elementwise absolute value of *input* (torch.abs equivalent).

    Real inputs produce a tensor of the same dtype; complex inputs produce
    the corresponding real dtype (|z| = sqrt(re^2 + im^2)). The result is a
    new contiguous tensor with the input's shape.
    """
    logger.debug("GEMS ABSOLUTE")
    x = input.contiguous()
    n_elements = x.numel()

    # Allocate the output up front so the empty-tensor case needs no launch.
    if x.is_complex():
        out = torch.empty(x.shape, dtype=x.real.dtype, device=x.device)
    else:
        out = torch.empty_like(x)
    # Guard the zero-element case: launching a zero-sized grid is wasted
    # work at best and rejected by some Triton versions at worst.
    if n_elements == 0:
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(input.device):
        if x.is_complex():
            # Interleaved (re, im) view; x is contiguous, so this view is too,
            # but .contiguous() keeps the kernel's stride assumption explicit.
            ri = torch.view_as_real(x).contiguous()
            _absolute_complex_kernel[grid](ri, out, n_elements, BLOCK_SIZE=1024)
        else:
            _absolute_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out