Coverage for src/flag_gems/ops/hypot.py: 56%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
def _torch_dtype_to_triton(dtype: torch.dtype):
    """Map a torch floating-point dtype to its triton-language counterpart.

    Raises:
        ValueError: if *dtype* has no triton equivalent handled here.
    """
    _MAP = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
        torch.float64: tl.float64,
    }
    try:
        return _MAP[dtype]
    except KeyError:
        # Suppress the KeyError context so callers see a clean ValueError,
        # matching the plain raise of the if-chain formulation.
        raise ValueError(f"Unsupported dtype for Triton conversion: {dtype}") from None
@triton.jit
def _hypot_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    COMPUTE_DTYPE: tl.constexpr,
):
    # Elementwise hypot over flat buffers: out[i] = sqrt(x[i]^2 + y[i]^2).
    # One program handles BLOCK_SIZE consecutive elements; flat offsets
    # assume x, y, and out are contiguous (callers .contiguous() them).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # other=0 is a harmless fill for masked-off lanes; those lanes are
    # never stored.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)
    # Promote to the compute precision (f32, or f64 for f64 outputs)
    # before any arithmetic.
    xf = x.to(COMPUTE_DTYPE)
    yf = y.to(COMPUTE_DTYPE)
    ax = tl.abs(xf)
    ay = tl.abs(yf)
    t = tl.maximum(ax, ay)
    m = tl.minimum(ax, ay)
    # Overflow-avoiding formulation: hypot = t * sqrt(1 + (m/t)^2) with
    # t = max(|x|,|y|), m = min(|x|,|y|), so the squared term is <= 1.
    # t_nz substitutes 1 for t when t == 0 only to make the division safe;
    # the where() below discards that branch's result in the t == 0 case.
    t_nz = tl.where(t > 0, t, 1).to(COMPUTE_DTYPE)
    r = m / t_nz
    # When t == 0 both inputs are zero, so m (== 0) is the correct result.
    res = tl.where(t > 0, t * tl.sqrt(1 + r * r), m)
    out_val = res.to(OUT_DTYPE)
    tl.store(out_ptr + offsets, out_val, mask=mask)
58def _infer_hypot_out_dtype(a: torch.Tensor, b: torch.Tensor) -> torch.dtype:
59 if a.is_complex() or b.is_complex():
60 raise NotImplementedError(
61 "Complex dtypes are not supported for hypot in this implementation."
62 )
63 if a.is_floating_point() or b.is_floating_point():
64 return torch.result_type(a, b)
65 return torch.get_default_dtype()
def _launch_hypot_kernel(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor):
    """Configure and launch ``_hypot_kernel`` writing into *out*.

    All three tensors are treated as flat contiguous buffers of
    ``out.numel()`` elements.
    """
    total = out.numel()
    if total == 0:
        return  # empty launch would be a no-op anyway; skip it

    result_dtype = out.dtype
    if result_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError(f"Unsupported output dtype for hypot: {result_dtype}")

    block = 1024
    # float64 results are computed in float64; everything else in float32.
    compute_dtype = tl.float64 if result_dtype == torch.float64 else tl.float32
    triton_out_dtype = _torch_dtype_to_triton(result_dtype)

    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(out.device):
        _hypot_kernel[grid](
            x,
            y,
            out,
            total,
            BLOCK_SIZE=block,
            OUT_DTYPE=triton_out_dtype,
            COMPUTE_DTYPE=compute_dtype,
        )
def hypot(a: torch.Tensor, b: torch.Tensor):
    """Elementwise ``sqrt(a**2 + b**2)`` with broadcasting (torch.hypot)."""
    logger.debug("GEMS HYPOT")
    # Dtype inference first: complex inputs raise NotImplementedError before
    # any device-mismatch ValueError, preserving the original error order.
    result_dtype = _infer_hypot_out_dtype(a, b)
    if b.device != a.device:
        raise ValueError("Input tensors must be on the same device")

    shape = torch.broadcast_shapes(a.shape, b.shape)
    result = torch.empty(shape, dtype=result_dtype, device=a.device)

    # Materialize broadcasted, contiguous operands so the kernel can use
    # flat offsets.
    lhs = a.expand(shape).contiguous()
    rhs = b.expand(shape).contiguous()

    _launch_hypot_kernel(lhs, rhs, result)
    return result
def hypot_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor):
    """Elementwise hypot of *a* and *b* written into *out*, returning *out*.

    *a* and *b* must be broadcastable to ``out.shape``. ``out.dtype`` must be
    a floating dtype supported by the kernel.

    Raises:
        ValueError: on unsupported ``out`` dtype or device mismatch.
    """
    logger.debug("GEMS HYPOT_OUT")
    if out.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError(f"Unsupported out dtype for hypot_out: {out.dtype}")
    # Same-device validation, consistent with hypot(); previously a mismatch
    # would only surface as an opaque kernel-launch failure.
    if a.device != out.device or b.device != out.device:
        raise ValueError("Input tensors must be on the same device")

    target_shape = out.shape
    x = a.expand(target_shape).contiguous()
    y = b.expand(target_shape).contiguous()

    if out.is_contiguous():
        _launch_hypot_kernel(x, y, out)
    else:
        # BUG FIX: the kernel stores via flat offsets, which assumes a
        # contiguous destination. Writing directly into a non-contiguous
        # `out` would scatter results to wrong memory locations, so stage
        # into a contiguous buffer and copy back through torch's
        # layout-aware copy_.
        staged = torch.empty(
            target_shape, dtype=out.dtype, device=out.device
        )
        _launch_hypot_kernel(x, y, staged)
        out.copy_(staged)
    return out