Coverage for src/flag_gems/experimental_ops/deg2rad_.py: 0%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def deg2rad_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Convert a flat float buffer from degrees to radians, in place.

    Each program instance owns one BLOCK_SIZE-wide slice of the buffer:
    it loads the slice, multiplies by pi / 180, and stores the result
    back to the same addresses. Out-of-range lanes are masked off.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    degrees = tl.load(x_ptr + idx, mask=in_bounds)
    # pi / 180 written as a literal so the kernel needs no host-side constant
    radians = degrees * 0.017453292519943295
    tl.store(x_ptr + idx, radians, mask=in_bounds)

16 

17 

# Keep a reference to the Triton kernel before defining the Python wrapper with
# the same name: the wrapper below shadows `deg2rad_`, so kernel launches must
# go through this alias.
deg2rad__kernel = deg2rad_

20 

21 

def deg2rad_(*args, **kwargs):
    """In-place degrees-to-radians conversion (mirrors ``Tensor.deg2rad_``).

    The input tensor is taken from the first positional argument, or from
    the ``input``/``self`` keyword arguments. Contiguous floating-point
    CUDA tensors are processed by the Triton kernel; everything else falls
    back to an in-place PyTorch scalar multiply. Returns the same tensor
    object, mutated in place.

    Raises:
        ValueError: no tensor argument was supplied.
        TypeError: the argument is not a ``torch.Tensor``.
    """
    # Locate the tensor operand.
    if args:
        tensor = args[0]
    else:
        tensor = kwargs.get("input", kwargs.get("self", None))
        if tensor is None:
            raise ValueError("deg2rad_ expects a tensor as its first argument.")

    if not isinstance(tensor, torch.Tensor):
        raise TypeError("deg2rad_ expects a torch.Tensor as input.")

    # Nothing to do for an empty tensor.
    numel = tensor.numel()
    if numel == 0:
        return tensor

    DEG_TO_RAD = 0.017453292519943295  # pi / 180

    # The Triton kernel only handles contiguous floating-point CUDA data;
    # any other tensor goes through the equivalent PyTorch in-place multiply.
    triton_eligible = (
        tensor.device.type == "cuda"
        and tensor.is_contiguous()
        and not tensor.is_complex()
        and tensor.is_floating_point()
    )
    if not triton_eligible:
        tensor.mul_(DEG_TO_RAD)
        return tensor

    # One program per BLOCK_SIZE-wide slice of the flattened tensor.
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
    deg2rad__kernel[grid](tensor, numel, BLOCK_SIZE=BLOCK_SIZE)
    return tensor