Coverage for src/flag_gems/experimental_ops/hypot.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch
2import triton
3import triton.language as tl
def _torch_dtype_to_triton(dtype: torch.dtype):
    """Map a floating-point torch dtype to its Triton counterpart.

    Raises:
        ValueError: if ``dtype`` is not one of the four supported float types.
    """
    torch_to_tl = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
        torch.float64: tl.float64,
    }
    try:
        return torch_to_tl[dtype]
    except KeyError:
        raise ValueError(f"Unsupported dtype for Triton conversion: {dtype}") from None
@triton.jit
def _hypot_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    COMPUTE_DTYPE: tl.constexpr,
):
    # Elementwise hypot: out[i] = sqrt(x[i]^2 + y[i]^2), evaluated in
    # COMPUTE_DTYPE and stored as OUT_DTYPE. Each program instance handles
    # one BLOCK_SIZE-sized contiguous chunk of the flat arrays.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the final, possibly partial block

    # Masked-off lanes load 0; they are never stored, so the filler value
    # only needs to be numerically harmless.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)

    # Promote inputs to the compute precision before any arithmetic.
    xf = x.to(COMPUTE_DTYPE)
    yf = y.to(COMPUTE_DTYPE)

    # Scaled formulation: with t = max(|x|, |y|) and m = min(|x|, |y|),
    # hypot = t * sqrt(1 + (m/t)^2). Dividing by the larger magnitude keeps
    # the squared term <= 1, avoiding the overflow/underflow the naive
    # sqrt(x*x + y*y) would hit for extreme magnitudes.
    ax = tl.abs(xf)
    ay = tl.abs(yf)
    t = tl.maximum(ax, ay)
    m = tl.minimum(ax, ay)
    # Substitute 1 for t in the divisor when t == 0 to avoid 0/0; such lanes
    # take the `m` branch of the where() below (m is 0 there as well).
    t_nz = tl.where(t > 0, t, 1).to(COMPUTE_DTYPE)
    r = m / t_nz
    res = tl.where(t > 0, t * tl.sqrt(1 + r * r), m)

    # Cast back to the requested output precision and write the live lanes.
    out_val = res.to(OUT_DTYPE)
    tl.store(out_ptr + offsets, out_val, mask=mask)
51def _infer_hypot_out_dtype(a: torch.Tensor, b: torch.Tensor) -> torch.dtype:
52 if a.is_complex() or b.is_complex():
53 raise NotImplementedError(
54 "Complex dtypes are not supported for hypot in this implementation."
55 )
56 if a.is_floating_point() or b.is_floating_point():
57 return torch.result_type(a, b)
58 # For integral/bool inputs, follow floating promotion behavior
59 return torch.get_default_dtype()
def _launch_hypot_kernel(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor):
    """Launch the Triton hypot kernel over already-broadcast tensors.

    ``x``, ``y`` and ``out`` must live on one CUDA device, be contiguous,
    and have the same number of elements; the kernel indexes all three with
    flat linear offsets.

    Raises:
        ValueError: on device, contiguity, or unsupported-dtype violations.
    """
    # Use real exceptions rather than `assert` for validation: asserts are
    # stripped under `python -O`, silently disabling these checks.
    if not (x.device == y.device == out.device):
        raise ValueError("All tensors must be on the same device")
    if not out.is_cuda:
        raise ValueError("Triton kernels require CUDA tensors")
    # The kernel reads/writes at flat offsets, so a non-contiguous tensor
    # would be traversed in the wrong element order.
    if not (x.is_contiguous() and y.is_contiguous() and out.is_contiguous()):
        raise ValueError("x, y and out must be contiguous")

    n_elements = out.numel()
    if n_elements == 0:
        return  # nothing to launch for empty tensors

    out_dtype = out.dtype
    if out_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError(f"Unsupported output dtype for hypot: {out_dtype}")

    OUT_DTYPE = _torch_dtype_to_triton(out_dtype)
    # Half-precision outputs accumulate in fp32; fp64 stays in fp64.
    COMPUTE_DTYPE = tl.float64 if out_dtype == torch.float64 else tl.float32

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _hypot_kernel[grid](
        x,
        y,
        out,
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
        OUT_DTYPE=OUT_DTYPE,
        COMPUTE_DTYPE=COMPUTE_DTYPE,
    )
def hypot(a: torch.Tensor, b: torch.Tensor):
    """Return elementwise sqrt(a**2 + b**2) with broadcasting, on CUDA."""
    # Dtype promotion first — this also rejects complex inputs up front.
    result_dtype = _infer_hypot_out_dtype(a, b)

    target_device = a.device
    if b.device != target_device:
        raise ValueError("Input tensors must be on the same device")
    if target_device.type != "cuda":
        raise ValueError("This implementation requires CUDA tensors")

    shape = torch.broadcast_shapes(a.shape, b.shape)
    result = torch.empty(shape, dtype=result_dtype, device=target_device)

    # Materialize the broadcast views as contiguous buffers; the kernel
    # addresses memory with flat offsets.
    lhs, rhs = (t.expand(shape).contiguous() for t in (a, b))
    _launch_hypot_kernel(lhs, rhs, result)
    return result
def hypot_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor):
    """Compute elementwise hypot(a, b) into the preallocated ``out`` tensor.

    ``a`` and ``b`` are broadcast to ``out``'s shape. All three tensors must
    be CUDA tensors on the same device; ``out`` must have a supported float
    dtype and be contiguous.

    Returns:
        ``out``, filled with the results.

    Raises:
        ValueError: on device, dtype, or contiguity violations.
    """
    # Validate device: everything must live on out's CUDA device.
    device = out.device
    if (not out.is_cuda) or a.device != device or b.device != device:
        raise ValueError(
            "All tensors (a, b, out) must be CUDA tensors on the same device"
        )

    # Validate dtype
    if out.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError(f"Unsupported out dtype for hypot_out: {out.dtype}")

    # BUG FIX: the kernel stores at flat linear offsets, which only matches
    # `out`'s memory layout when it is contiguous. Previously a
    # non-contiguous `out` was silently filled in the wrong element order.
    if not out.is_contiguous():
        raise ValueError("hypot_out requires a contiguous `out` tensor")

    # Broadcast inputs to out's shape (expand raises if incompatible) and
    # flatten them into contiguous buffers for the kernel.
    target_shape = out.shape
    x = a.expand(target_shape).contiguous()
    y = b.expand(target_shape).contiguous()

    _launch_hypot_kernel(x, y, out)
    return out