# Coverage report residue: src/flag_gems/experimental_ops/logaddexp.py — 0% of 70
# statements covered (coverage.py v7.6.9, created 2026-03-22 16:54 +0800).
1import torch
2import triton
3import triton.language as tl
@triton.jit
def logaddexp_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Elementwise logaddexp over flat buffers: out[i] = log(exp(x[i]) + exp(y[i])).
    # Each program instance handles one BLOCK_SIZE-wide slice of the flat range.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the ragged final block

    # `other=0.0` only fills masked-off lanes; those lanes are never stored.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    # Compute in float32 regardless of the storage dtype (fp16/bf16 inputs).
    xf32 = x.to(tl.float32)
    yf32 = y.to(tl.float32)

    # Numerically stable form: max(x, y) + log(1 + exp(-|x - y|)),
    # which avoids overflow of exp() for large-magnitude inputs.
    delta = xf32 - yf32
    adelta = tl.abs(delta)
    m = tl.maximum(xf32, yf32)
    # NOTE(review): log(1 + exp(.)) loses precision vs a log1p-style form when
    # exp(-adelta) is tiny; also x == y == -inf yields NaN here (delta is
    # inf - inf) where torch.logaddexp returns -inf — confirm whether these
    # edge cases matter for this op.
    res = m + tl.log(1.0 + tl.exp(-adelta))

    # Cast back to the output buffer's element type before storing.
    out_ty = out_ptr.dtype.element_ty
    tl.store(out_ptr + offsets, res.to(out_ty), mask=mask)
28def _ensure_cuda_tensor(obj, device, dtype):
29 if torch.is_tensor(obj):
30 return obj.to(device=device, dtype=dtype)
31 else:
32 return torch.tensor(obj, device=device, dtype=dtype)
35def _common_float_dtype(x: torch.Tensor, y: torch.Tensor):
36 dt = torch.result_type(x, y)
37 if dt not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
38 dt = torch.get_default_dtype()
39 return dt
42def _launch_logaddexp_kernel(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor):
43 assert x.is_cuda and y.is_cuda and out.is_cuda, "All tensors must be on CUDA device"
44 assert (
45 x.numel() == y.numel() == out.numel()
46 ), "Input and output must have the same number of elements"
48 x_flat = x.contiguous().view(-1)
49 y_flat = y.contiguous().view(-1)
50 out_flat = out.contiguous().view(-1)
52 n_elements = out_flat.numel()
53 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
54 logaddexp_kernel[grid](x_flat, y_flat, out_flat, n_elements, BLOCK_SIZE=1024)
56 # If out was non-contiguous, copy results back into original layout
57 if not out.is_contiguous():
58 out.copy_(out_flat.view_as(out))
def logaddexp(x, y):
    """Elementwise ``log(exp(x) + exp(y))`` computed on the GPU.

    Accepts tensors or Python scalars; at least one argument must already be
    a CUDA tensor (it determines the compute device). Inputs are promoted to
    a common floating dtype and broadcast against each other.

    Raises:
        ValueError: if neither input is a CUDA tensor.
    """
    # Pick the compute device from whichever input is already on CUDA,
    # preferring x.
    device = None
    for candidate in (x, y):
        if torch.is_tensor(candidate) and candidate.is_cuda:
            device = candidate.device
            break
    if device is None:
        raise ValueError("At least one input must be a CUDA tensor")

    # Promote to a common floating dtype (scalars are wrapped first so
    # torch.result_type can see them).
    dtype = _common_float_dtype(
        x if torch.is_tensor(x) else torch.tensor(x),
        y if torch.is_tensor(y) else torch.tensor(y),
    )

    # Move both operands onto the device with the chosen dtype, then broadcast.
    xb, yb = torch.broadcast_tensors(
        _ensure_cuda_tensor(x, device, dtype),
        _ensure_cuda_tensor(y, device, dtype),
    )

    result = torch.empty_like(xb, dtype=dtype, device=device)
    _launch_logaddexp_kernel(xb, yb, result)
    return result
def logaddexp_out(x, y, out):
    """Compute ``logaddexp(x, y)`` into the preallocated CUDA tensor *out*.

    The computation device and dtype come from *out*; inputs are converted,
    broadcast, and must match *out*'s shape after broadcasting.

    Raises:
        ValueError: if *out* is not a floating-point CUDA tensor, or its
            shape does not match the broadcasted input shape.
    """
    if not (torch.is_tensor(out) and out.is_cuda):
        raise ValueError("out must be a CUDA tensor")

    # out dictates both where and in which dtype we compute.
    device, out_dtype = out.device, out.dtype
    if out_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError("out dtype must be a floating point type")

    # Convert inputs to out's device/dtype, then broadcast them together.
    xb, yb = torch.broadcast_tensors(
        _ensure_cuda_tensor(x, device, out_dtype),
        _ensure_cuda_tensor(y, device, out_dtype),
    )

    # The caller-supplied buffer must already have the broadcasted shape.
    if tuple(out.shape) != tuple(xb.shape):
        raise ValueError(
            f"out shape {tuple(out.shape)} does not match broadcasted shape {tuple(xb.shape)}"
        )

    _launch_logaddexp_kernel(xb, yb, out)
    return out