Coverage for src/flag_gems/experimental_ops/logaddexp2.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def logaddexp2_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise logaddexp2 over flat 1-D buffers: out = log2(2**x + 2**y).

    Each program instance handles one BLOCK_SIZE-wide slice of the buffers;
    the grid is sized by the host launcher so all n_elements are covered.
    Inputs may be any loadable dtype; math is done in fp32 and the store
    casts back to out_ptr's element type.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block: lanes past n_elements are masked out.
    mask = offs < n_elements

    # Load inputs and upcast to fp32 for numerics
    x = tl.load(x_ptr + offs, mask=mask, other=0).to(tl.float32)
    y = tl.load(y_ptr + offs, mask=mask, other=0).to(tl.float32)

    # Numerically-stable logaddexp2:
    # logaddexp2(x, y) = m + log2(1 + 2^(-|x - y|)), where m = max(x, y)
    # Subtracting the max keeps the exponent argument <= 0, so exp() cannot
    # overflow even for large-magnitude inputs.
    ln2 = 0.6931471805599453
    inv_ln2 = 1.4426950408889634

    d = tl.abs(x - y)
    m = tl.maximum(x, y)
    t = tl.exp(-d * ln2)  # 2^(-|x-y|) = exp(-(abs(x-y)) * ln(2))
    res = m + tl.log(1.0 + t) * inv_ln2  # log2(1 + t) = ln(1+t) / ln(2)

    # Store; Triton will cast to the dtype of out_ptr as needed
    tl.store(out_ptr + offs, res, mask=mask)
30def _broadcast_and_check(x, y):
31 # Convert scalars to tensors
32 if not isinstance(x, torch.Tensor):
33 x = torch.as_tensor(x)
34 if not isinstance(y, torch.Tensor):
35 y = torch.as_tensor(y)
36 # Broadcast
37 bx, by = torch.broadcast_tensors(x, y)
38 return bx, by
41def _choose_out_dtype(x, y, out=None):
42 if out is not None:
43 return out.dtype
44 # Prefer highest precision floating dtype present; else default dtype
45 float_priority = [torch.float64, torch.float32, torch.bfloat16, torch.float16]
46 for dt in float_priority:
47 if x.dtype == dt or y.dtype == dt:
48 return dt
49 # If none are floating, use default dtype
50 return torch.get_default_dtype()
def _launch_kernel(xc, yc, outc):
    """Launch logaddexp2_kernel over flat, contiguous, same-length buffers."""
    numel = outc.numel()
    # Empty launch is a no-op; Triton grids must be non-empty.
    if numel == 0:
        return
    block = 1024
    grid = (triton.cdiv(numel, block),)
    logaddexp2_kernel[grid](xc, yc, outc, numel, BLOCK_SIZE=block)
def logaddexp2(x, y):
    """Compute log2(2**x + 2**y) elementwise with broadcasting.

    Uses the Triton kernel when both operands live on the same CUDA device
    and are non-complex; otherwise defers to the eager aten implementation.
    """
    bx, by = _broadcast_and_check(x, y)

    # Fallback for unsupported devices or complex dtype
    cuda_pair = bx.device.type == "cuda" and by.device.type == "cuda"
    same_device = bx.device == by.device
    if not cuda_pair or not same_device or bx.is_complex() or by.is_complex():
        return torch.ops.aten.logaddexp2(bx, by)

    result = torch.empty(
        bx.shape, device=bx.device, dtype=_choose_out_dtype(bx, by, out=None)
    )

    # Kernel works on flat contiguous memory; result is freshly allocated
    # and therefore already contiguous, so its flat view aliases it.
    _launch_kernel(
        bx.contiguous().view(-1),
        by.contiguous().view(-1),
        result.view(-1),
    )
    return result
def logaddexp2_out(x, y, out):
    """Out-of-place variant: write log2(2**x + 2**y) into `out` and return it.

    Raises ValueError if `out` is None or its shape does not match the
    broadcast shape of the inputs. Non-CUDA / mixed-device / complex cases
    defer to the eager aten implementation.
    """
    if out is None:
        raise ValueError("out tensor must be provided for logaddexp2_out")

    bx, by = _broadcast_and_check(x, y)

    # Fallback for unsupported devices or complex dtype
    all_cuda = (
        out.device.type == "cuda"
        and bx.device.type == "cuda"
        and by.device.type == "cuda"
    )
    colocated = bx.device == by.device == out.device
    any_complex = bx.is_complex() or by.is_complex() or out.is_complex()
    if not all_cuda or not colocated or any_complex:
        # Use PyTorch implementation for unsupported cases
        return torch.ops.aten.logaddexp2.out(bx, by, out=out)

    # Shape and dtype checks
    if out.shape != bx.shape:
        raise ValueError(
            f"out tensor has shape {out.shape}, expected {bx.shape} from broadcast"
        )
    # We allow dtype differences; computation will write to out's dtype

    # Prepare contiguous buffers
    xc = bx.contiguous().view(-1)
    yc = by.contiguous().view(-1)

    if not out.is_contiguous():
        # Compute into a temporary contiguous buffer then copy back
        staging = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_kernel(xc, yc, staging.view(-1))
        out.copy_(staging)
    else:
        _launch_kernel(xc, yc, out.view(-1))
    return out