# Coverage for src/flag_gems/experimental_ops/hypot_.py: 0% (56 statements)
# coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def hypot_(
    x_ptr,  # Pointer to first input (will be output if in-place).
    y_ptr,  # Pointer to second input (broadcasted/contiguous).
    out_ptr,  # Pointer to output buffer.
    n_elements,  # Number of elements to process.
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise hypot kernel: out[i] = sqrt(x[i]**2 + y[i]**2).

    Math is done in float32 and cast back to the input dtype.  Uses the
    classic scaling trick so large inputs (|x| near float32 max) do not
    overflow in the squaring step, matching torch.hypot's overflow-safe
    behavior.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    x32 = x.to(tl.float32)
    y32 = y.to(tl.float32)

    # Overflow-safe hypot: factor out the larger magnitude so the ratio
    # squared is <= 1:  hypot = big * sqrt(1 + (small/big)^2).
    ax = tl.abs(x32)
    ay = tl.abs(y32)
    big = tl.maximum(ax, ay)
    small = tl.minimum(ax, ay)
    # Guard big == 0 (both inputs zero) to avoid 0/0 -> NaN; result is 0.
    ratio = small / tl.where(big == 0.0, 1.0, big)
    out32 = big * tl.sqrt(1.0 + ratio * ratio)

    out_cast = out32.to(x.dtype)
    tl.store(out_ptr + offsets, out_cast, mask=mask)
# Keep a handle to the Triton kernel before the Python wrapper defined
# below rebinds the module-level name `hypot_`.
_hypot_kernel = hypot_
def hypot_(*args, **kwargs):
    """In-place hypot, mirroring torch.ops.aten.hypot_(self, other).

    Computes ``self = sqrt(self**2 + other**2)`` elementwise with a
    Triton kernel and returns the mutated ``self`` tensor.

    Args (positional or keyword):
        self/input: CUDA tensor, mutated in place; its dtype is kept.
        other: tensor or Python scalar, broadcast to ``self``'s shape.

    Raises:
        TypeError: if an argument is missing or ``self`` is not a Tensor.
        ValueError: if ``self`` is not a CUDA tensor.
        RuntimeError: if ``other`` cannot be broadcast to ``self``'s
            shape — an in-place op must not grow its output.
    """
    # Extract arguments similar to torch.ops.aten.hypot_(self, other).
    x = None
    other = None
    if len(args) >= 1:
        x = args[0]
    if len(args) >= 2:
        other = args[1]
    if x is None:
        x = kwargs.get("input", kwargs.get("self", None))
    if other is None:
        other = kwargs.get("other", None)

    if x is None or other is None:
        raise TypeError("hypot_ expects two arguments: self and other")
    if not isinstance(x, torch.Tensor):
        raise TypeError("self must be a torch.Tensor")
    if not x.is_cuda:
        raise ValueError("hypot_ Triton kernel only supports CUDA tensors")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to do; skip the transfer/cast/broadcast work entirely.
        return x

    device = x.device

    # In-place ops keep self's dtype; move/cast `other` accordingly.
    if isinstance(other, torch.Tensor):
        other_t = other.to(device)
    else:
        other_t = torch.tensor(other, device=device)
    if other_t.dtype != x.dtype:
        other_t = other_t.to(x.dtype)

    # Broadcast `other` to x's shape.  For an in-place op the result
    # shape must be exactly x.shape, so a failure here (e.g. `other`
    # has a larger shape) is a genuine error and must propagate — the
    # old broadcast_tensors fallback silently truncated oversized
    # operands instead of raising like aten does.
    other_b = torch.broadcast_to(other_t, x.shape)

    # The kernel addresses memory linearly, so both operands must be
    # contiguous buffers.
    x_c = x if x.is_contiguous() else x.contiguous()
    other_c = other_b if other_b.is_contiguous() else other_b.contiguous()

    # If x is contiguous, write directly in-place into x; otherwise
    # write to a temp buffer and copy back through x's original strides.
    out_buf = x_c if x.is_contiguous() else torch.empty_like(x_c)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _hypot_kernel[grid](x_c, other_c, out_buf, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if not x.is_contiguous():
        x.copy_(out_buf)

    return x