Coverage for src/flag_gems/experimental_ops/abs_.py: 0%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def abs_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite a flat buffer with its elementwise absolute value.

    Each program instance owns one BLOCK_SIZE-wide slice of the buffer:
    it loads its slice, takes |x|, and stores the result back to the
    same addresses. A mask guards the ragged tail of the final block.
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements
    vals = tl.load(x_ptr + offs, mask=in_bounds)
    tl.store(x_ptr + offs, tl.abs(vals), mask=in_bounds)

15 

16 

# Alias the kernel before defining the Python wrapper with the same name:
# the wrapper below rebinds `abs_`, so this is the only surviving handle
# to the compiled Triton kernel that the wrapper launches.
abs__kernel = abs_

19 

20 

def abs_(*args, **kwargs):
    """In-place elementwise absolute value via the Triton kernel.

    Mirrors ``torch.Tensor.abs_``: the tensor is taken either as the
    first positional argument or as the ``input`` keyword argument,
    overwritten with its absolute value, and returned.

    Returns:
        The same tensor object, modified in place (returned unmodified
        for empty and bool tensors, where there is nothing to compute).

    Raises:
        ValueError: if no tensor is supplied.
        TypeError: if the argument is not a ``torch.Tensor``, or is
            complex (|z| is real-valued, so it cannot be stored in place).
        RuntimeError: if the tensor is not a contiguous CUDA tensor.
    """
    # Extract input tensor
    x = args[0] if args else kwargs.get("input", None)
    if x is None:
        raise ValueError(
            "abs_ expects a tensor as the first positional argument or 'input' keyword argument."
        )
    if not isinstance(x, torch.Tensor):
        raise TypeError("abs_ expects a torch.Tensor as input.")

    # Trivial/unsupported cases: no launch needed for empty tensors,
    # abs is the identity on bool, and complex in-place abs is rejected.
    if x.numel() == 0:
        return x
    if x.dtype == torch.bool:
        return x
    if x.is_complex():
        raise TypeError("abs_ does not support complex tensors in-place.")

    # Validate with real exceptions rather than `assert`: asserts are
    # stripped under `python -O`, which would silently drop these checks
    # and let the kernel launch with invalid inputs.
    if not x.is_cuda:
        raise RuntimeError("abs_ expects a CUDA tensor.")
    if not x.is_contiguous():
        raise RuntimeError("abs_ expects a contiguous tensor.")

    # Launch: one program per BLOCK_SIZE elements; the kernel masks the
    # tail. BLOCK_SIZE is fixed here, so the grid is a plain tuple — no
    # need for a meta-callback lambda.
    n_elements = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    abs__kernel[grid](x, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x