Coverage for src/flag_gems/ops/logaddexp.py: 71%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
@triton.jit
def logaddexp_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise log(exp(x) + exp(y)) over flat buffers of n_elements.

    Computes in float32 regardless of input dtype and casts back to the
    output buffer's element type on store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    xf32 = x.to(tl.float32)
    yf32 = y.to(tl.float32)

    # Numerically stable form: max(x, y) + log(1 + exp(-|x - y|)).
    delta = xf32 - yf32
    adelta = tl.abs(delta)
    m = tl.maximum(xf32, yf32)
    # When x == y, delta-based math breaks for infinities: (-inf) - (-inf)
    # and (+inf) - (+inf) are NaN, yet logaddexp(-inf, -inf) = -inf and
    # logaddexp(+inf, +inf) = +inf. For equal inputs the exact answer is
    # x + log(2), which also covers both infinite cases (±inf + log(2) = ±inf)
    # and matches the general formula for equal finite inputs. NaN inputs
    # still propagate because NaN == NaN is false.
    res = tl.where(
        xf32 == yf32,
        xf32 + 0.6931471805599453,  # log(2)
        m + tl.log(1.0 + tl.exp(-adelta)),
    )

    out_ty = out_ptr.dtype.element_ty
    tl.store(out_ptr + offsets, res.to(out_ty), mask=mask)
35def _ensure_cuda_tensor(obj, device, dtype):
36 if torch.is_tensor(obj):
37 return obj.to(device=device, dtype=dtype)
38 else:
39 return torch.tensor(obj, device=device, dtype=dtype)
42def _common_float_dtype(x: torch.Tensor, y: torch.Tensor):
43 dt = torch.result_type(x, y)
44 if dt not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
45 dt = torch.get_default_dtype()
46 return dt
def _launch_logaddexp_kernel(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor):
    """Flatten the (already broadcast) inputs and run the Triton kernel.

    *out* receives the result; if it is non-contiguous the kernel writes to
    a contiguous scratch copy which is then copied back into *out*.
    """
    assert (
        x.numel() == y.numel() == out.numel()
    ), "Input and output must have the same number of elements"

    flat_x = x.contiguous().view(-1)
    flat_y = y.contiguous().view(-1)
    flat_out = out.contiguous().view(-1)
    total = flat_out.numel()

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(x.device):
        logaddexp_kernel[grid](flat_x, flat_y, flat_out, total, BLOCK_SIZE=1024)

    # .contiguous() copied a non-contiguous `out`, so the kernel wrote to
    # the copy; propagate the results back into the caller's layout.
    if not out.is_contiguous():
        out.copy_(flat_out.view_as(out))
def logaddexp(x, y):
    """Compute ``log(exp(x) + exp(y))`` elementwise on a CUDA device.

    Args:
        x, y: tensors or Python scalars; at least one must be a CUDA tensor.
              Inputs are broadcast against each other.

    Returns:
        A new CUDA tensor with the broadcast shape.

    Raises:
        ValueError: if neither input is a CUDA tensor.
    """
    logger.debug("GEMS LOGADDEXP")
    # Pick the device from whichever argument is a CUDA tensor.
    device = None
    if torch.is_tensor(x) and x.is_cuda:
        device = x.device
    if device is None and torch.is_tensor(y) and y.is_cuda:
        device = y.device
    if device is None:
        raise ValueError("At least one input must be a CUDA tensor")

    # torch.result_type accepts Python scalars directly and applies scalar
    # (weak) promotion. Wrapping a scalar in torch.tensor() first would use
    # tensor-tensor promotion instead — e.g. a float16 tensor combined with
    # 2.0 would wrongly promote to float32 rather than staying float16.
    dtype = torch.result_type(x, y)
    if dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        dtype = torch.get_default_dtype()

    # Move/cast both operands onto the chosen device and dtype.
    x_t = _ensure_cuda_tensor(x, device, dtype)
    y_t = _ensure_cuda_tensor(y, device, dtype)

    # Broadcast to a common shape and allocate the output.
    xb, yb = torch.broadcast_tensors(x_t, y_t)
    out = torch.empty_like(xb, dtype=dtype, device=device)

    _launch_logaddexp_kernel(xb, yb, out)
    return out
def logaddexp_out(x, y, out):
    """Compute ``log(exp(x) + exp(y))`` elementwise into *out*.

    The computation device and dtype are taken from *out*, which must be a
    floating-point tensor whose shape equals the broadcast shape of the
    inputs.

    Raises:
        ValueError: if *out* is not a tensor, has a non-float dtype, or its
            shape does not match the broadcast shape.
    """
    logger.debug("GEMS LOGADDEXP_OUT")
    if not torch.is_tensor(out):
        raise ValueError("out must be a tensor")

    # *out* dictates where and in what precision we compute.
    target_device = out.device
    target_dtype = out.dtype
    floating = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if target_dtype not in floating:
        raise ValueError("out dtype must be a floating point type")

    lhs = _ensure_cuda_tensor(x, target_device, target_dtype)
    rhs = _ensure_cuda_tensor(y, target_device, target_dtype)
    lhs_b, rhs_b = torch.broadcast_tensors(lhs, rhs)

    if tuple(out.shape) != tuple(lhs_b.shape):
        raise ValueError(
            f"out shape {tuple(out.shape)} does not match broadcasted shape {tuple(lhs_b.shape)}"
        )

    _launch_logaddexp_kernel(lhs_b, rhs_b, out)
    return out