Coverage for src/flag_gems/experimental_ops/rad2deg_.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def rad2deg__kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Convert one block of elements from radians to degrees, in place.

    Each program instance handles BLOCK_SIZE consecutive elements of the
    flat buffer at ``x_ptr``; ``n_elements`` bounds the tail block.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    # deg = rad * (180/pi); the constant folds at compile time.
    converted = vals * (180.0 / 3.141592653589793)
    tl.store(x_ptr + idx, converted, mask=in_bounds)

16 

17 

def rad2deg_(*args, **kwargs):
    """In-place conversion of a CUDA floating-point tensor from radians to degrees.

    The tensor may be passed positionally or via the ``input``/``self``
    keyword (mirroring ``torch.Tensor.rad2deg_`` call conventions).

    Returns:
        The same tensor object, mutated in place.

    Raises:
        ValueError: if no tensor argument was supplied.
        TypeError: if the argument is not a floating-point ``torch.Tensor``.
        AssertionError: if the tensor is not on a CUDA device.
    """
    tensor = args[0] if args else kwargs.get("input", kwargs.get("self", None))
    if tensor is None:
        raise ValueError("rad2deg_ expects a tensor as its first argument")
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("rad2deg_ expects a torch.Tensor")
    if not tensor.is_floating_point():
        raise TypeError(
            "rad2deg_ only supports floating point tensors for in-place operation"
        )
    if not tensor.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device")

    # The kernel assumes a flat dense layout; for non-contiguous inputs we
    # run on a contiguous copy and write the result back into the original.
    result = tensor
    work = tensor if tensor.is_contiguous() else tensor.contiguous()
    copy_back = work is not tensor

    total = work.numel()
    if total == 0:
        # Nothing to launch; preserve in-place semantics by returning the input.
        return result

    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    rad2deg__kernel[grid](work, total, BLOCK_SIZE=1024)

    if copy_back:
        result.copy_(work)
    return result