Coverage for src/flag_gems/experimental_ops/rad2deg_.py: 0%
36 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def rad2deg__kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Convert a flat float buffer from radians to degrees, in place.

    Each program instance handles one BLOCK_SIZE-wide slice of the buffer;
    out-of-range lanes are masked off on both the load and the store.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    # deg = rad * (180 / pi); the quotient folds to a float constant
    tl.store(x_ptr + idx, vals * (180.0 / 3.141592653589793), mask=in_bounds)
def rad2deg_(*args, **kwargs):
    """In-place radians-to-degrees conversion via the Triton kernel.

    The tensor may be passed positionally or under the ``input`` / ``self``
    keywords. Returns the same tensor object that was passed in, mutated.

    Raises:
        ValueError: no tensor argument was supplied.
        TypeError: argument is not a torch.Tensor, or is not floating point.
        AssertionError: tensor is not on a CUDA device.
    """
    if args:
        x = args[0]
    else:
        x = kwargs.get("input", kwargs.get("self", None))

    if x is None:
        raise ValueError("rad2deg_ expects a tensor as its first argument")
    if not isinstance(x, torch.Tensor):
        raise TypeError("rad2deg_ expects a torch.Tensor")
    if not x.is_floating_point():
        raise TypeError(
            "rad2deg_ only supports floating point tensors for in-place operation"
        )
    if not x.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device")

    result = x
    # The kernel assumes a flat contiguous buffer: stage through a contiguous
    # copy when needed, then write the converted values back afterwards.
    work = x if x.is_contiguous() else x.contiguous()

    total = work.numel()
    if total == 0:
        return result

    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    rad2deg__kernel[grid](work, total, BLOCK_SIZE=1024)

    if work is not result:
        result.copy_(work)
    return result