Coverage for src/flag_gems/experimental_ops/absolute_.py: 0%
34 statements
coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import torch
2import triton
3import triton.language as tl
# Triton kernel: overwrite each of the first ``n_elements`` values stored at
# ``x_ptr`` with its absolute value, in place. Each program instance handles one
# BLOCK_SIZE-wide slice of the flat buffer; lanes at or beyond ``n_elements``
# are masked off on both the load and the store.
@triton.jit
def absolute_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Promote the program id to int64 so ``pid * BLOCK_SIZE`` and the offsets
    # derived from it cannot wrap around for buffers with more than
    # 2**31 - 1 elements (program_id is int32 by default).
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    values = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, tl.abs(values), mask=mask)
# Bind the Triton kernel to a second name *before* the Python wrapper below is
# defined. The wrapper reuses the name ``absolute_`` (presumably to match the
# aten op name — confirm), which rebinds it; without this alias the kernel
# would become unreachable. The wrapper launches the kernel through this name.
absolute__kernel = absolute_
# Dtypes for which the Triton kernel is launched; any other dtype is handed to
# torch.ops.aten.absolute_ instead.
_ABSOLUTE_TRITON_DTYPES = frozenset(
    {
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    }
)


def absolute_(*args, **kwargs):
    """Replace every element of a tensor with its absolute value, in place.

    The tensor is taken from the first positional argument if one is given;
    otherwise from the ``self`` keyword, falling back to ``input``.

    CUDA tensors whose dtype is in ``_ABSOLUTE_TRITON_DTYPES`` are processed by
    the Triton kernel ``absolute__kernel``. A non-contiguous CUDA tensor is
    processed through a contiguous copy that is then written back with
    ``copy_``. Every other tensor (CPU, or an unlisted dtype) is delegated to
    ``torch.ops.aten.absolute_``.

    Returns:
        The same tensor object that was passed in, modified in place.

    Raises:
        TypeError: if the selected argument is missing or not a torch.Tensor.
    """
    if args:
        x = args[0]
    else:
        x = kwargs.get("self", None)
        if x is None:
            x = kwargs.get("input", None)
    if not isinstance(x, torch.Tensor):
        raise TypeError("absolute_ expects a torch.Tensor as the first argument")

    # If tensor has no elements, nothing to do.
    if x.numel() == 0:
        return x

    if not (x.is_cuda and x.dtype in _ABSOLUTE_TRITON_DTYPES):
        # CPU tensors and dtypes the kernel is not launched for go to PyTorch.
        torch.ops.aten.absolute_(x)
        return x

    # The kernel indexes a flat, densely packed buffer, so a non-contiguous
    # tensor is processed via a contiguous copy and written back afterwards.
    # Keeping CUDA tensors on this path also avoids calling
    # torch.ops.aten.absolute_ for them. NOTE(review): if this wrapper is
    # registered as aten::absolute_'s CUDA implementation, that call would
    # re-enter this function — confirm against the registration code.
    buf = x if x.is_contiguous() else x.contiguous()
    n_elements = buf.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    # Launch on the tensor's own device rather than whichever device is current.
    with torch.cuda.device(x.device):
        absolute__kernel[grid](buf, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    if buf is not x:
        x.copy_(buf)
    return x