Coverage for src/flag_gems/experimental_ops/hypot.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

def _torch_dtype_to_triton(dtype: torch.dtype):
    """Map a floating-point torch dtype to its Triton language equivalent.

    Args:
        dtype: one of torch.float16 / bfloat16 / float32 / float64.

    Returns:
        The matching ``tl`` dtype constant.

    Raises:
        ValueError: for any dtype outside the supported floating set.
    """
    table = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
        torch.float64: tl.float64,
    }
    converted = table.get(dtype)
    if converted is None:
        raise ValueError(f"Unsupported dtype for Triton conversion: {dtype}")
    return converted

16 

17 

# Element-wise hypot kernel: out[i] = sqrt(x[i]^2 + y[i]^2), computed with the
# scaled formulation max * sqrt(1 + (min/max)^2) so the intermediate squares
# cannot overflow (or underflow) the compute dtype.
@triton.jit
def _hypot_kernel(
    x_ptr,  # pointer to the first input (flat; launcher passes contiguous tensors)
    y_ptr,  # pointer to the second input (flat; launcher passes contiguous tensors)
    out_ptr,  # pointer to the output buffer (flat)
    n_elements,  # total number of elements to process
    BLOCK_SIZE: tl.constexpr,  # elements handled by one program instance
    OUT_DTYPE: tl.constexpr,  # triton dtype used when storing the result
    COMPUTE_DTYPE: tl.constexpr,  # triton dtype used for the arithmetic
):
    # One program per BLOCK_SIZE-sized slice of the flattened tensors.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block where offsets may run past n_elements.
    mask = offsets < n_elements

    # Masked lanes load 0; they are never stored, so the value is irrelevant.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)

    # Promote inputs to the compute dtype (fp32, or fp64 for double outputs).
    xf = x.to(COMPUTE_DTYPE)
    yf = y.to(COMPUTE_DTYPE)

    # With t = max(|x|, |y|) and m = min(|x|, |y|), the ratio m/t is <= 1,
    # so r * r stays in range even when x or y is near the dtype's limit.
    ax = tl.abs(xf)
    ay = tl.abs(yf)
    t = tl.maximum(ax, ay)
    m = tl.minimum(ax, ay)
    # Substitute 1 when t == 0 so the division below is well defined.
    t_nz = tl.where(t > 0, t, 1).to(COMPUTE_DTYPE)
    r = m / t_nz
    # When t == 0 both inputs are zero, so the answer is m (== 0).
    res = tl.where(t > 0, t * tl.sqrt(1 + r * r), m)

    out_val = res.to(OUT_DTYPE)
    tl.store(out_ptr + offsets, out_val, mask=mask)

49 

50 

51def _infer_hypot_out_dtype(a: torch.Tensor, b: torch.Tensor) -> torch.dtype: 

52 if a.is_complex() or b.is_complex(): 

53 raise NotImplementedError( 

54 "Complex dtypes are not supported for hypot in this implementation." 

55 ) 

56 if a.is_floating_point() or b.is_floating_point(): 

57 return torch.result_type(a, b) 

58 # For integral/bool inputs, follow floating promotion behavior 

59 return torch.get_default_dtype() 

60 

61 

def _launch_hypot_kernel(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor):
    """Launch the element-wise hypot Triton kernel over flat tensors.

    Args:
        x, y: broadcast-expanded, contiguous inputs on the same CUDA device
            as ``out`` (the kernel indexes all three with flat offsets).
        out: pre-allocated output tensor; its dtype selects both the store
            dtype and the accumulation precision used by the kernel.

    Raises:
        ValueError: if the tensors are on different devices, are not CUDA
            tensors, or ``out`` has an unsupported floating dtype.
    """
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, and ValueError matches the validation style of the
    # public wrappers in this module.
    if not (x.device == y.device == out.device):
        raise ValueError("All tensors must be on the same device")
    if not out.is_cuda:
        raise ValueError("Triton kernels require CUDA tensors")

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    out_dtype = out.dtype
    if out_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError(f"Unsupported output dtype for hypot: {out_dtype}")

    OUT_DTYPE = _torch_dtype_to_triton(out_dtype)
    # Half-precision outputs accumulate in fp32; fp64 keeps full precision.
    COMPUTE_DTYPE = tl.float64 if out_dtype == torch.float64 else tl.float32

    _hypot_kernel[grid](
        x,
        y,
        out,
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
        OUT_DTYPE=OUT_DTYPE,
        COMPUTE_DTYPE=COMPUTE_DTYPE,
    )

88 

89 

def hypot(a: torch.Tensor, b: torch.Tensor):
    """Compute element-wise hypot(a, b) = sqrt(a**2 + b**2) with a Triton kernel.

    Inputs are broadcast against each other; the result dtype follows
    torch's floating-point promotion rules (integral inputs promote to the
    default floating dtype).

    Raises:
        ValueError: if inputs are on different devices or not CUDA tensors.
        NotImplementedError: for complex inputs (via dtype inference).
    """
    result_dtype = _infer_hypot_out_dtype(a, b)

    target_device = a.device
    if b.device != target_device:
        raise ValueError("Input tensors must be on the same device")
    if target_device.type != "cuda":
        raise ValueError("This implementation requires CUDA tensors")

    shape = torch.broadcast_shapes(a.shape, b.shape)
    result = torch.empty(shape, dtype=result_dtype, device=target_device)

    # The kernel indexes flat memory, so materialize contiguous broadcast copies.
    lhs = a.expand(shape).contiguous()
    rhs = b.expand(shape).contiguous()

    _launch_hypot_kernel(lhs, rhs, result)
    return result

108 

109 

def hypot_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor):
    """Compute element-wise hypot(a, b) into a caller-provided ``out`` tensor.

    Inputs are broadcast to ``out.shape``; ``out`` keeps its own dtype.

    Args:
        a, b: CUDA input tensors broadcastable to ``out.shape``.
        out: CUDA output tensor of a supported floating dtype.

    Returns:
        ``out``, filled with the results.

    Raises:
        ValueError: on device mismatch, non-CUDA tensors, or unsupported
            ``out`` dtype.
    """
    # Validate device placement
    device = out.device
    if (not out.is_cuda) or a.device != device or b.device != device:
        raise ValueError(
            "All tensors (a, b, out) must be CUDA tensors on the same device"
        )

    # Validate dtype
    if out.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError(f"Unsupported out dtype for hypot_out: {out.dtype}")

    # Validate/broadcast inputs to out shape (expand raises on mismatch)
    target_shape = out.shape
    x = a.expand(target_shape).contiguous()
    y = b.expand(target_shape).contiguous()

    if out.is_contiguous():
        _launch_hypot_kernel(x, y, out)
    else:
        # Bug fix: the kernel stores through flat contiguous offsets, which
        # would scatter into wrong storage locations for a strided `out`.
        # Compute into a contiguous scratch buffer and copy back respecting
        # out's strides.
        scratch = torch.empty(target_shape, dtype=out.dtype, device=device)
        _launch_hypot_kernel(x, y, scratch)
        out.copy_(scratch)
    return out