Coverage for src/flag_gems/experimental_ops/deg2rad.py: 0%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import math
3import torch
4import triton
5import triton.language as tl
@triton.jit
def deg2rad_kernel(x_ptr, y_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    """Elementwise y = x * scale over a flat buffer of n_elements values.

    Each program instance handles one BLOCK_SIZE-wide slice; `mask` guards
    the final, possibly partial block. `scale` is pi/180 for deg->rad.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(y_ptr + idx, vals * scale, mask=in_bounds)
def _launch_deg2rad_kernel(x_contig: torch.Tensor, out_contig: torch.Tensor):
    """Launch deg2rad_kernel over two same-length, contiguous CUDA tensors.

    No-ops on empty tensors. The grid is one program per BLOCK_SIZE chunk
    of the output's element count.
    """
    assert x_contig.is_cuda and out_contig.is_cuda, "Tensors must be on CUDA device"
    total = out_contig.numel()
    if total == 0:
        return
    # Degrees-to-radians conversion factor, passed as a runtime scalar.
    factor = math.pi / 180.0

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    deg2rad_kernel[grid](x_contig, out_contig, total, factor, BLOCK_SIZE=1024)
def deg2rad(*args, **kwargs):
    """Convert angles from degrees to radians via a Triton kernel.

    Accepts the input tensor positionally or as the ``input``/``self``
    keyword. Floating dtypes are preserved; integer/bool inputs are
    promoted to float32. CUDA-only; complex dtypes are rejected.

    Raises:
        TypeError: no tensor supplied, or the argument is not a Tensor.
        AssertionError: the tensor is not on a CUDA device.
        NotImplementedError: the tensor is complex.
    """
    # Locate the single input tensor among args/kwargs.
    if args:
        x = args[0]
    else:
        x = kwargs.get("input")
        if x is None:
            x = kwargs.get("self")
        if x is None:
            raise TypeError("deg2rad expected a single input tensor")

    if not isinstance(x, torch.Tensor):
        raise TypeError("deg2rad input must be a torch.Tensor")
    if not x.is_cuda:
        raise AssertionError("deg2rad expects CUDA tensors")
    if x.is_complex():
        raise NotImplementedError(
            "Complex tensors are not supported in this Triton implementation"
        )

    # Keep floating dtypes; promote everything else (int/bool) to float32.
    result_dtype = x.dtype if x.is_floating_point() else torch.float32

    src = x.to(result_dtype).contiguous()
    result = torch.empty_like(src, dtype=result_dtype, device=x.device)
    _launch_deg2rad_kernel(src.view(-1), result.view(-1))
    return result.view_as(x)
def deg2rad_out(*args, **kwargs):
    """Out-variant of deg2rad: write degrees->radians of ``input`` into ``out``.

    Accepts (input, out) positionally, or via the ``input``/``self`` and
    ``out`` keywords. Computation is done in ``out``'s dtype; a non-contiguous
    ``out`` is computed into a contiguous scratch buffer and copied back.

    Raises:
        TypeError: missing arguments or non-Tensor arguments.
        AssertionError: either tensor is not on a CUDA device.
        NotImplementedError: either tensor is complex.
        RuntimeError: element-count or device mismatch between input and out.
    """
    # Locate input/out. NOTE: the previous implementation used
    # `kwargs.get("input") or kwargs.get("self")`, which calls bool() on the
    # tensor — raising RuntimeError for tensors with more than one element
    # and silently discarding a 0-d zero tensor. Use explicit None checks.
    if len(args) >= 1:
        x = args[0]
    else:
        x = kwargs.get("input")
        if x is None:
            x = kwargs.get("self")
    if len(args) >= 2:
        out = args[1]
    else:
        out = kwargs.get("out")

    if x is None or out is None:
        raise TypeError("deg2rad_out expected (input, out) tensors")

    if not isinstance(x, torch.Tensor) or not isinstance(out, torch.Tensor):
        raise TypeError("deg2rad_out arguments must be torch.Tensor")

    if not x.is_cuda or not out.is_cuda:
        raise AssertionError("deg2rad_out expects CUDA tensors")

    if x.is_complex() or out.is_complex():
        raise NotImplementedError(
            "Complex tensors are not supported in this Triton implementation"
        )

    if out.numel() != x.numel():
        raise RuntimeError(
            "deg2rad_out: 'out' must have the same number of elements as 'input'"
        )

    if out.device != x.device:
        raise RuntimeError("deg2rad_out: 'out' must be on the same device as 'input'")

    # Compute in the dtype of 'out' to match out-variant semantics.
    x_contig = x.to(out.dtype).contiguous()

    # The kernel requires contiguous memory; stage through a scratch buffer
    # when 'out' is not contiguous, then copy the result back.
    if out.is_contiguous():
        out_contig = out
        need_copy_back = False
    else:
        out_contig = torch.empty_like(out, memory_format=torch.contiguous_format)
        need_copy_back = True

    _launch_deg2rad_kernel(x_contig.view(-1), out_contig.view(-1))

    if need_copy_back:
        out.copy_(out_contig)

    return out