Coverage for src/flag_gems/ops/hypot.py: 56%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

# Lookup table from torch floating dtypes to their Triton equivalents.
def _torch_dtype_to_triton(dtype: torch.dtype):
    """Map a torch floating-point dtype to its Triton counterpart.

    Supports float16, bfloat16, float32 and float64; any other dtype
    raises ValueError.
    """
    mapping = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
        torch.float64: tl.float64,
    }
    triton_dtype = mapping.get(dtype)
    if triton_dtype is None:
        raise ValueError(f"Unsupported dtype for Triton conversion: {dtype}")
    return triton_dtype

23 

24 

@triton.jit
def _hypot_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    COMPUTE_DTYPE: tl.constexpr,
):
    """Elementwise hypot over flat buffers: out[i] = sqrt(x[i]^2 + y[i]^2).

    Uses the scaled form t * sqrt(1 + (m/t)^2) with t = max(|x|, |y|) and
    m = min(|x|, |y|) so the intermediate squares cannot overflow/underflow.
    Masked lanes load 0 and are never stored.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x_ptr + idx, mask=in_bounds, other=0).to(COMPUTE_DTYPE)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0).to(COMPUTE_DTYPE)

    abs_a = tl.abs(a)
    abs_b = tl.abs(b)
    big = tl.maximum(abs_a, abs_b)
    small = tl.minimum(abs_a, abs_b)
    # Substitute 1 where big == 0 so the ratio never divides by zero;
    # that branch's result is discarded by the final where().
    safe_big = tl.where(big > 0, big, 1).to(COMPUTE_DTYPE)
    ratio = small / safe_big
    result = tl.where(big > 0, big * tl.sqrt(1 + ratio * ratio), small)

    tl.store(out_ptr + idx, result.to(OUT_DTYPE), mask=in_bounds)

56 

57 

58def _infer_hypot_out_dtype(a: torch.Tensor, b: torch.Tensor) -> torch.dtype: 

59 if a.is_complex() or b.is_complex(): 

60 raise NotImplementedError( 

61 "Complex dtypes are not supported for hypot in this implementation." 

62 ) 

63 if a.is_floating_point() or b.is_floating_point(): 

64 return torch.result_type(a, b) 

65 return torch.get_default_dtype() 

66 

67 

def _launch_hypot_kernel(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor):
    """Launch the Triton hypot kernel over same-shaped flat buffers.

    All three tensors are addressed with flat offsets, so x, y and out are
    assumed contiguous with ``out.numel()`` elements each. A zero-element
    output is a no-op; a non-float out dtype raises ValueError.
    """
    total = out.numel()
    if total == 0:
        return

    supported = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if out.dtype not in supported:
        raise ValueError(f"Unsupported output dtype for hypot: {out.dtype}")

    block_size = 1024
    grid = (triton.cdiv(total, block_size),)

    triton_out_dtype = _torch_dtype_to_triton(out.dtype)
    # float64 outputs need float64 accumulation; everything else computes
    # in float32 for speed.
    compute_dtype = tl.float64 if out.dtype == torch.float64 else tl.float32

    with torch_device_fn.device(out.device):
        _hypot_kernel[grid](
            x,
            y,
            out,
            total,
            BLOCK_SIZE=block_size,
            OUT_DTYPE=triton_out_dtype,
            COMPUTE_DTYPE=compute_dtype,
        )

93 

94 

def hypot(a: torch.Tensor, b: torch.Tensor):
    """Compute sqrt(a**2 + b**2) elementwise with broadcasting.

    Returns a freshly allocated tensor on the inputs' shared device with
    the promoted floating dtype. Raises ValueError if the inputs live on
    different devices.
    """
    logger.debug("GEMS HYPOT")
    result_dtype = _infer_hypot_out_dtype(a, b)
    if a.device != b.device:
        raise ValueError("Input tensors must be on the same device")

    shape = torch.broadcast_shapes(a.shape, b.shape)
    result = torch.empty(shape, dtype=result_dtype, device=a.device)

    # Materialize the broadcast so the kernel can use flat indexing.
    lhs = a.expand(shape).contiguous()
    rhs = b.expand(shape).contiguous()

    _launch_hypot_kernel(lhs, rhs, result)
    return result

110 

111 

def hypot_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor):
    """Compute sqrt(a**2 + b**2) elementwise, writing into ``out``.

    ``a`` and ``b`` are broadcast to ``out.shape``. Raises ValueError for
    unsupported ``out`` dtypes or when the tensors are not all on the same
    device. Returns ``out``.
    """
    logger.debug("GEMS HYPOT_OUT")
    if out.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError(f"Unsupported out dtype for hypot_out: {out.dtype}")
    # The kernel addresses all three buffers with raw pointers, so every
    # tensor must live on the same device (mirrors the check in hypot()).
    if a.device != out.device or b.device != out.device:
        raise ValueError("Input tensors must be on the same device")

    target_shape = out.shape
    x = a.expand(target_shape).contiguous()
    y = b.expand(target_shape).contiguous()

    if out.is_contiguous():
        _launch_hypot_kernel(x, y, out)
    else:
        # The kernel stores with flat linear offsets and would scramble a
        # non-contiguous layout; compute into a dense buffer and copy back.
        tmp = torch.empty(target_shape, dtype=out.dtype, device=out.device)
        _launch_hypot_kernel(x, y, tmp)
        out.copy_(tmp)
    return out