Coverage for src/flag_gems/experimental_ops/negative.py: 0%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def negative_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise negation: store ``-x[i]`` into ``out[i]`` for ``i < n_elements``.

    Each program instance handles one tile of ``BLOCK_SIZE`` consecutive
    elements of the flat buffers; lanes past ``n_elements`` in the last
    tile are masked off for both the load and the store.
    """
    # Promote the program id to int64 before scaling. With int32 arithmetic,
    # ``pid * BLOCK_SIZE`` wraps once the tensor holds >= 2**31 elements,
    # yielding negative offsets that still pass the ``< n_elements`` mask
    # and would read/write out of bounds.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, -x, mask=mask)

14 

15 

def _launch_negative(x: torch.Tensor, out: torch.Tensor):
    """Launch ``negative_kernel`` so that ``out`` holds ``-x``; return ``out``.

    Preconditions, checked with ``assert`` in this order: both tensors are
    on CUDA, their dtypes match, their element counts match, and each is
    contiguous. When there are no elements, no kernel is launched and
    ``out`` is returned untouched.
    """
    assert x.is_cuda and out.is_cuda, "Tensors must be on CUDA device"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"
    assert (
        x.numel() == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert out.is_contiguous(), "Output tensor must be contiguous"

    total = x.numel()
    if total > 0:
        block = 1024
        # One program per tile of `block` elements; the tail tile is masked
        # inside the kernel.
        num_programs = triton.cdiv(total, block)
        negative_kernel[(num_programs,)](x, out, total, BLOCK_SIZE=block)
    return out

33 

34 

def negative(x: torch.Tensor):
    """Return a new tensor equal to ``-x``.

    ``x`` is made contiguous (a no-op when it already is). The result is
    allocated with the same shape and dtype as that contiguous view, so it
    is itself contiguous.
    """
    # Materialize the contiguous input once. Calling ``x.contiguous()``
    # twice copied a non-contiguous input twice, and the first copy was used
    # only to size the output.
    x_contig = x.contiguous()
    out = torch.empty_like(x_contig)
    _launch_negative(x_contig, out)
    return out

39 

40 

def negative_out(x: torch.Tensor, out: torch.Tensor):
    """Write ``-x`` into ``out`` and return ``out``.

    A non-contiguous ``x`` is first copied into a contiguous buffer, matching
    ``negative``. If ``out`` is not contiguous, the result is computed into a
    contiguous scratch tensor of the same shape and dtype and then copied
    into ``out``, so ``out``'s own strides are respected. Device, dtype and
    element-count checks are left to ``_launch_negative`` (AssertionError).
    """
    x_contig = x.contiguous()
    if out.is_contiguous():
        _launch_negative(x_contig, out)
        return out
    # The kernel indexes memory linearly, so it cannot write a strided
    # ``out`` directly; stage the result and let ``copy_`` handle layout.
    scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
    _launch_negative(x_contig, scratch)
    out.copy_(scratch)
    return out