Coverage for src/flag_gems/experimental_ops/negative.py: 0%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def negative_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise negation kernel: writes -x[i] to out[i] for i < n_elements.

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    input; out-of-range lanes are masked off.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, -vals, mask=in_bounds)
def _launch_negative(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Launch ``negative_kernel`` to write ``-x`` into ``out``.

    Both tensors must be contiguous CUDA tensors with the same dtype and
    the same number of elements.

    Args:
        x: contiguous CUDA input tensor.
        out: contiguous CUDA output tensor, same dtype/numel as ``x``.

    Returns:
        ``out``, for chaining.

    Raises:
        ValueError: if any precondition on device, dtype, size, or
            contiguity is violated.
    """
    # Validate with real exceptions instead of `assert`: asserts are
    # stripped under `python -O`, which would silently skip all checks.
    if not (x.is_cuda and out.is_cuda):
        raise ValueError("Tensors must be on CUDA device")
    if x.dtype != out.dtype:
        raise ValueError("Input and output must have the same dtype")
    if x.numel() != out.numel():
        raise ValueError("Input and output must have the same number of elements")
    if not x.is_contiguous():
        raise ValueError("Input tensor must be contiguous")
    if not out.is_contiguous():
        raise ValueError("Output tensor must be contiguous")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch (a zero-size grid is invalid).
        return out

    BLOCK_SIZE = 1024
    # Grid size derives from the meta BLOCK_SIZE so autotuning-style
    # overrides of the constexpr stay consistent with the launch shape.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    negative_kernel[grid](x, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return out
def negative(x: torch.Tensor) -> torch.Tensor:
    """Return a new contiguous CUDA tensor containing ``-x``.

    Args:
        x: CUDA input tensor (any layout; copied to contiguous if needed).

    Returns:
        Freshly allocated contiguous tensor with the elementwise negation.
    """
    # Materialize the contiguous form exactly once. The original called
    # x.contiguous() twice; for a non-contiguous input each call produces
    # a separate copy, doubling the copy cost.
    xc = x.contiguous()
    out = torch.empty_like(xc)
    _launch_negative(xc, out)
    return out
def negative_out(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Write the elementwise negation of ``x`` into ``out`` and return it.

    ``x`` and ``out`` must satisfy the launch preconditions (contiguous
    CUDA tensors, matching dtype and element count).
    """
    # _launch_negative already returns `out`, so forward its result.
    return _launch_negative(x, out)