Coverage for src/flag_gems/experimental_ops/absolute_.py: 0%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def absolute_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite a flat buffer of ``n_elements`` with its elementwise absolute value."""
    # Each program instance owns one contiguous BLOCK_SIZE-wide slice.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements  # guard the ragged final block
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(x_ptr + idx, tl.abs(vals), mask=in_bounds)

15 

16 

# Preserve a handle to the compiled Triton kernel: the Python wrapper defined
# below reuses the name `absolute_`, which would otherwise shadow the kernel.
absolute__kernel = absolute_

19 

20 

def absolute_(*args, **kwargs):
    """In-place absolute value, mirroring ``torch.Tensor.absolute_``.

    The tensor may be passed positionally or via the ``self``/``input``
    keyword (aten schemas name the in-place receiver ``self``). Contiguous
    CUDA tensors of a supported dtype are processed by the Triton kernel;
    everything else falls back to PyTorch's native in-place op.

    Returns:
        The input tensor, mutated in place.

    Raises:
        TypeError: if no ``torch.Tensor`` argument is provided.
    """
    if args:
        x = args[0]
    else:
        x = kwargs.get("self", kwargs.get("input"))
    # isinstance(None, torch.Tensor) is False, so this also rejects a missing arg.
    if not isinstance(x, torch.Tensor):
        raise TypeError("absolute_ expects a torch.Tensor as the first argument")

    # Empty tensor: nothing to do (and avoids launching a zero-sized grid).
    if x.numel() == 0:
        return x

    # Dtypes this Triton kernel handles (for uint8, abs is a no-op copy).
    supported_dtypes = frozenset({
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    })

    use_triton = x.is_cuda and x.is_contiguous() and x.dtype in supported_dtypes

    if not use_triton:
        # Fallback for CPU tensors, non-contiguous layouts, or unsupported
        # dtypes; Tensor.abs_() is the public equivalent of aten::absolute_.
        x.abs_()
        return x

    n_elements = x.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    absolute__kernel[grid](x, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x