Coverage report for src/flag_gems/experimental_ops/xlogy.py: 0% of 75 statements covered (coverage.py v7.6.9, generated 2026-03-28 12:23 +0800).
1import torch
2import triton
3import triton.language as tl
@triton.jit
def xlogy_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise xlogy over flat contiguous buffers.

    Follows torch.xlogy semantics:
        out = NaN            where y is NaN   (takes precedence)
        out = 0              where x == 0
        out = x * log(y)     otherwise

    Args:
        x_ptr, y_ptr: pointers to the broadcasted, contiguous inputs.
        out_ptr: pointer to the output buffer (same number of elements).
        n_elements: total element count.
        BLOCK_SIZE: compile-time block width.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked-out lanes load neutral fillers (x=0, y=1) so log() stays finite.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=1)

    # Compute in float32; tl.store casts back to the output pointer's dtype.
    x_f32 = x.to(tl.float32)
    y_f32 = y.to(tl.float32)

    res = x_f32 * tl.log(y_f32)
    res = tl.where(x_f32 == 0.0, 0.0, res)
    # BUG FIX: torch.xlogy propagates NaN from y even where x == 0; the
    # previous code returned 0 for (x=0, y=NaN). NaN != NaN detects NaN.
    res = tl.where(y_f32 != y_f32, float("nan"), res)

    tl.store(out_ptr + offsets, res, mask=mask)
25def _ensure_tensor_on_device(obj, device, dtype):
26 if isinstance(obj, torch.Tensor):
27 return obj.to(device=device, dtype=dtype)
28 else:
29 return torch.as_tensor(obj, device=device, dtype=dtype)
32def _prepare_tensors(self, other, out=None):
33 # Determine device
34 if isinstance(self, torch.Tensor):
35 device = self.device
36 elif isinstance(other, torch.Tensor):
37 device = other.device
38 else:
39 raise ValueError("At least one of the inputs must be a Tensor.")
41 if device.type != "cuda":
42 raise ValueError("Triton kernels require CUDA tensors.")
44 # Type promotion following PyTorch semantics
45 if isinstance(self, torch.Tensor) and isinstance(other, torch.Tensor):
46 result_dtype = torch.result_type(self, other)
47 elif isinstance(self, torch.Tensor):
48 other_tmp = torch.as_tensor(other)
49 result_dtype = torch.result_type(self, other_tmp)
50 else:
51 self_tmp = torch.as_tensor(self)
52 result_dtype = torch.result_type(self_tmp, other)
54 t_self = _ensure_tensor_on_device(self, device, result_dtype)
55 t_other = _ensure_tensor_on_device(other, device, result_dtype)
57 # Broadcast to a common shape
58 b_self, b_other = torch.broadcast_tensors(t_self, t_other)
60 # Prepare output
61 if out is None:
62 out_tensor = torch.empty(b_self.shape, device=device, dtype=result_dtype)
63 return b_self.contiguous(), b_other.contiguous(), out_tensor, out_tensor
64 else:
65 if out.device != device:
66 raise ValueError("Output tensor must be on the same device as inputs.")
67 # Out dtype/shape should be able to hold result
68 expected_shape = b_self.shape
69 if out.shape != expected_shape:
70 raise ValueError(
71 f"Output tensor has shape {out.shape}, expected {expected_shape}."
72 )
73 if out.dtype != result_dtype:
74 raise ValueError(
75 f"Output tensor has dtype {out.dtype}, expected {result_dtype}."
76 )
77 # If out is contiguous, write directly; otherwise use a temporary
78 if out.is_contiguous():
79 return b_self.contiguous(), b_other.contiguous(), out, out
80 else:
81 tmp = torch.empty(expected_shape, device=device, dtype=result_dtype)
82 return b_self.contiguous(), b_other.contiguous(), tmp, out
def _launch_xlogy(self, other, out=None):
    """Prepare operands, dispatch the xlogy kernel, and return the result."""
    x, y, dst, final_out = _prepare_tensors(self, other, out)
    total = dst.numel()

    # Launch only when there is work; a zero-sized grid is invalid.
    if total:
        grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
        xlogy_kernel[grid](x, y, dst, total, BLOCK_SIZE=1024)

    # When `out` was non-contiguous, `dst` is a staging buffer: copy back.
    if dst is not final_out:
        final_out.copy_(dst)
    return final_out
# Wrappers corresponding to ATen operator interfaces
def xlogy_Tensor(self: torch.Tensor, other: torch.Tensor):
    """Tensor/tensor variant."""
    return _launch_xlogy(self, other)


def xlogy_Scalar_Other(self: torch.Tensor, other):
    """Tensor `self`, scalar `other` variant."""
    return _launch_xlogy(self, other)


def xlogy_Scalar_Self(self, other: torch.Tensor):
    """Scalar `self`, tensor `other` variant."""
    return _launch_xlogy(self, other)


def xlogy_OutTensor(self: torch.Tensor, other: torch.Tensor, out: torch.Tensor):
    """Tensor/tensor variant writing into `out`."""
    return _launch_xlogy(self, other, out=out)


def xlogy_OutScalar_Self(self, other: torch.Tensor, out: torch.Tensor):
    """Scalar `self` variant writing into `out`."""
    return _launch_xlogy(self, other, out=out)


def xlogy_OutScalar_Other(self: torch.Tensor, other, out: torch.Tensor):
    """Scalar `other` variant writing into `out`."""
    return _launch_xlogy(self, other, out=out)