# Coverage report: src/flag_gems/experimental_ops/deg2rad.py — 0% of 75 statements covered
# (generated by coverage.py v7.6.9 at 2026-03-28 12:23 +0800)

import math

import torch
import triton
import triton.language as tl

7 

@triton.jit
def deg2rad_kernel(x_ptr, y_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    """Elementwise scale: ``y[i] = x[i] * scale`` for every ``i < n_elements``.

    One program instance processes one contiguous tile of ``BLOCK_SIZE``
    elements; out-of-range lanes are masked off.
    """
    tile_id = tl.program_id(axis=0)
    idx = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(y_ptr + idx, vals * scale, mask=in_bounds)

17 

18 

19def _launch_deg2rad_kernel(x_contig: torch.Tensor, out_contig: torch.Tensor): 

20 assert x_contig.is_cuda and out_contig.is_cuda, "Tensors must be on CUDA device" 

21 n_elements = out_contig.numel() 

22 if n_elements == 0: 

23 return 

24 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

25 scale = math.pi / 180.0 

26 deg2rad_kernel[grid](x_contig, out_contig, n_elements, scale, BLOCK_SIZE=1024) 

27 

28 

def deg2rad(*args, **kwargs):
    """Convert angles from degrees to radians with a Triton kernel.

    The single input tensor may be passed positionally or via the ``input``
    (or ``self``) keyword. Floating-point inputs keep their dtype; integer
    and bool inputs are promoted to float32. Requires a CUDA tensor; complex
    tensors are rejected.
    """
    # Locate the one expected tensor argument.
    if args:
        x = args[0]
    else:
        x = kwargs.get("input")
        if x is None:
            x = kwargs.get("self")
        if x is None:
            raise TypeError("deg2rad expected a single input tensor")

    if not isinstance(x, torch.Tensor):
        raise TypeError("deg2rad input must be a torch.Tensor")

    if not x.is_cuda:
        raise AssertionError("deg2rad expects CUDA tensors")

    if x.is_complex():
        raise NotImplementedError(
            "Complex tensors are not supported in this Triton implementation"
        )

    # Floating dtypes are preserved; integer/bool compute in float32.
    result_dtype = x.dtype if x.is_floating_point() else torch.float32

    x_contig = x.to(result_dtype).contiguous()
    out = torch.empty_like(x_contig, dtype=result_dtype, device=x.device)

    _launch_deg2rad_kernel(x_contig.view(-1), out.view(-1))

    return out.view_as(x)

64 

65 

def deg2rad_out(*args, **kwargs):
    """Out-variant of ``deg2rad``: writes ``input * pi / 180`` into ``out``.

    Accepts ``(input, out)`` positionally, or via the ``input``/``self`` and
    ``out`` keywords. Computation is performed in ``out``'s dtype.

    Raises:
        TypeError: if an argument is missing or not a torch.Tensor.
        AssertionError: if either tensor is not on a CUDA device.
        NotImplementedError: for complex tensors.
        RuntimeError: if 'out' differs from 'input' in numel or device.
    """
    if len(args) >= 1:
        x = args[0]
    else:
        # BUGFIX: the previous `kwargs.get("input") or kwargs.get("self")`
        # evaluated bool(tensor), which raises RuntimeError for multi-element
        # tensors and silently discards a falsy zero-valued scalar tensor.
        x = kwargs.get("input")
        if x is None:
            x = kwargs.get("self")

    if len(args) >= 2:
        out = args[1]
    else:
        out = kwargs.get("out")

    if x is None or out is None:
        raise TypeError("deg2rad_out expected (input, out) tensors")

    if not isinstance(x, torch.Tensor) or not isinstance(out, torch.Tensor):
        raise TypeError("deg2rad_out arguments must be torch.Tensor")

    if not x.is_cuda or not out.is_cuda:
        raise AssertionError("deg2rad_out expects CUDA tensors")

    if x.is_complex() or out.is_complex():
        raise NotImplementedError(
            "Complex tensors are not supported in this Triton implementation"
        )

    if out.numel() != x.numel():
        raise RuntimeError(
            "deg2rad_out: 'out' must have the same number of elements as 'input'"
        )

    if out.device != x.device:
        raise RuntimeError("deg2rad_out: 'out' must be on the same device as 'input'")

    # Compute in the dtype of 'out' to match out-variant semantics.
    x_contig = x.to(out.dtype).contiguous()

    if out.is_contiguous():
        # Kernel can write straight into 'out'.
        out_contig = out
        need_copy_back = False
    else:
        # Kernel needs contiguous memory; stage into a scratch buffer and
        # copy back afterwards, preserving 'out''s original layout.
        out_contig = torch.empty_like(out, memory_format=torch.contiguous_format)
        need_copy_back = True

    _launch_deg2rad_kernel(x_contig.view(-1), out_contig.view(-1))

    if need_copy_back:
        out.copy_(out_contig)

    return out