Coverage for src/flag_gems/experimental_ops/special_xlog1py.py: 0%
90 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _xlog1py_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise kernel: out = 0 where x == 0, else x * log(1 + y).

    Matches torch.special.xlog1py semantics: the result is defined to be 0
    whenever x == 0, even if y would make log(1 + y) inf or nan (e.g.
    y == -1 or y is nan). The previous unconditional x * tl.log(1.0 + y)
    produced 0 * (-inf) = nan in those cases. A nan in x still propagates,
    because (nan == 0.0) is false and nan * anything is nan.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # Short-circuit x == 0 so the log term can never contaminate the output.
    out = tl.where(x == 0.0, 0.0, x * tl.log(1.0 + y))
    tl.store(out_ptr + offsets, out, mask=mask)
18def _ensure_cuda_tensor(t):
19 if not isinstance(t, torch.Tensor):
20 raise TypeError("Expected a torch.Tensor")
21 if t.device.type != "cuda":
22 raise ValueError("Tensors must be on CUDA device")
23 return t
def _prepare_inputs(x, y):
    """Validate, broadcast, and normalize a pair of operands.

    Both inputs must be CUDA tensors. They are broadcast against each
    other, the promoted result dtype is recorded, and contiguous float32
    copies are returned for the flat 1-D kernel launch.

    Returns:
        (x_fp32, y_fp32, result_dtype)
    """
    bx, by = torch.broadcast_tensors(
        _ensure_cuda_tensor(x), _ensure_cuda_tensor(y)
    )
    result_dtype = torch.result_type(bx, by)
    return (
        bx.to(torch.float32).contiguous(),
        by.to(torch.float32).contiguous(),
        result_dtype,
    )
def _launch_xlog1py(x_fp32, y_fp32, out_fp32):
    """Launch the Triton xlog1py kernel over the flattened element range."""
    total = out_fp32.numel()

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _xlog1py_kernel[grid](x_fp32, y_fp32, out_fp32, total, BLOCK_SIZE=1024)
def special_xlog1py(x, y):
    """Compute xlog1py(x, y) elementwise for two CUDA tensors.

    Inputs are broadcast against each other; the computation runs in
    float32 and the result is cast back to the promoted result dtype.
    """
    x32, y32, result_dtype = _prepare_inputs(x, y)
    result = torch.empty_like(x32)
    _launch_xlog1py(x32, y32, result)
    # When the promoted dtype is already float32 the buffer is returned as-is.
    return result if result_dtype == torch.float32 else result.to(result_dtype)
def special_xlog1py_other_scalar(x, other):
    """xlog1py where the second operand is a Python scalar.

    The scalar is materialized as a 0-d tensor on x's device/dtype and
    then handled by the regular broadcast path.
    """
    x = _ensure_cuda_tensor(x)
    y = torch.as_tensor(other, device=x.device, dtype=x.dtype)
    x32, y32, result_dtype = _prepare_inputs(x, y)
    result = torch.empty_like(x32)
    _launch_xlog1py(x32, y32, result)
    if result_dtype != torch.float32:
        result = result.to(result_dtype)
    return result
def special_xlog1py_self_scalar(self, other):
    """xlog1py where the first operand is a Python scalar.

    The scalar is materialized as a 0-d tensor on other's device/dtype
    and then handled by the regular broadcast path.
    """
    other = _ensure_cuda_tensor(other)
    x = torch.as_tensor(self, device=other.device, dtype=other.dtype)
    x32, y32, result_dtype = _prepare_inputs(x, other)
    result = torch.empty_like(x32)
    _launch_xlog1py(x32, y32, result)
    if result_dtype != torch.float32:
        result = result.to(result_dtype)
    return result
def special_xlog1py_out(x, y, out):
    """xlog1py(x, y) written into a preallocated *out* tensor.

    *out* must be a CUDA tensor whose shape equals the broadcast shape of
    x and y; its dtype may differ from float32 — the result is cast on copy.

    Raises:
        ValueError: if out's shape does not match the broadcast shape.
    """
    out = _ensure_cuda_tensor(out)
    x32, y32, _ = _prepare_inputs(
        _ensure_cuda_tensor(x), _ensure_cuda_tensor(y)
    )
    expected = torch.broadcast_shapes(x32.shape, y32.shape)
    if out.shape != expected:
        raise ValueError(f"Out tensor has shape {out.shape}, expected {expected}")
    scratch = torch.empty(expected, device=out.device, dtype=torch.float32)
    _launch_xlog1py(x32, y32, scratch)
    out.copy_(scratch.to(out.dtype))
    return out
def special_xlog1py_self_scalar_out(self, other, out):
    """Scalar-first xlog1py variant that writes into a preallocated *out*.

    Raises:
        ValueError: if out's shape does not match the broadcast shape.
    """
    out = _ensure_cuda_tensor(out)
    other = _ensure_cuda_tensor(other)
    x = torch.as_tensor(self, device=other.device, dtype=other.dtype)
    x32, y32, _ = _prepare_inputs(x, other)
    expected = torch.broadcast_shapes(x32.shape, y32.shape)
    if out.shape != expected:
        raise ValueError(f"Out tensor has shape {out.shape}, expected {expected}")
    scratch = torch.empty(expected, device=out.device, dtype=torch.float32)
    _launch_xlog1py(x32, y32, scratch)
    out.copy_(scratch.to(out.dtype))
    return out
def special_xlog1py_other_scalar_out(x, other, out):
    """Scalar-second xlog1py variant that writes into a preallocated *out*.

    Raises:
        ValueError: if out's shape does not match the broadcast shape.
    """
    out = _ensure_cuda_tensor(out)
    x = _ensure_cuda_tensor(x)
    y = torch.as_tensor(other, device=x.device, dtype=x.dtype)
    x32, y32, _ = _prepare_inputs(x, y)
    expected = torch.broadcast_shapes(x32.shape, y32.shape)
    if out.shape != expected:
        raise ValueError(f"Out tensor has shape {out.shape}, expected {expected}")
    scratch = torch.empty(expected, device=out.device, dtype=torch.float32)
    _launch_xlog1py(x32, y32, scratch)
    out.copy_(scratch.to(out.dtype))
    return out