Coverage for src/flag_gems/experimental_ops/logical_xor_.py: 0% (41 statements).
Report generated by coverage.py v7.6.9 on 2026-03-24 15:40 +0800.
1import torch
2import triton
3import triton.language as tl
@triton.jit
def logical_xor_(a_ptr, b_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise in-place logical XOR over flat buffers.

    Writes ``bool(a[i]) ^ bool(b[i])`` back into ``a_ptr[i]`` for every
    in-bounds offset of this program's BLOCK_SIZE-wide slice.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    lhs = tl.load(a_ptr + idx, mask=in_bounds)
    rhs = tl.load(b_ptr + idx, mask=in_bounds)

    # Coerce both operands to truthiness, XOR, and overwrite the first operand.
    result = (lhs != 0) ^ (rhs != 0)
    tl.store(a_ptr + idx, result, mask=in_bounds)
# Preserve reference to the Triton kernel before defining the Python wrapper with the same name.
# The wrapper defined next shadows the module-level name ``logical_xor_``; this
# alias is the only remaining handle the wrapper uses to launch the kernel.
logical_xor___triton_kernel = logical_xor_
def logical_xor_(*args, **kwargs):
    """In-place logical XOR, mirroring ``torch.Tensor.logical_xor_(self, other)``.

    Accepts ``(self, other)`` positionally, or via the keywords
    ``input``/``self`` and ``other``. ``other`` may be a tensor (broadcast to
    ``self``'s shape) or a Python scalar.

    Returns:
        ``self``, mutated in place.

    Raises:
        TypeError: if ``self`` is not a ``torch.Tensor`` or ``other`` is missing.
        RuntimeError: if ``self`` is not dtype bool, is not on a CUDA device,
            or ``other`` cannot be broadcast to ``self``'s shape.
    """
    # Parse inputs: expect (self, other), positionally or by keyword.
    if len(args) >= 2:
        self, other = args[0], args[1]
    else:
        self = kwargs.get("input", kwargs.get("self", None))
        other = kwargs.get("other", None)

    if not isinstance(self, torch.Tensor):
        raise TypeError("logical_xor_: first argument must be a torch.Tensor")
    # Fail early with a clear message instead of letting torch.tensor(None, ...)
    # raise an opaque error further down.
    if other is None:
        raise TypeError("logical_xor_: missing required argument 'other'")
    if self.dtype is not torch.bool:
        raise RuntimeError(
            "logical_xor_: in-place logical operations require self to have dtype torch.bool"
        )

    if not self.is_cuda:
        raise RuntimeError("logical_xor_: tensor must be on CUDA device")

    # Prepare 'other' as a tensor on the same device.
    if isinstance(other, torch.Tensor):
        other_t = other.to(device=self.device)
    else:
        # Create scalar tensor with dtype matching self (bool).
        other_t = torch.tensor(other, device=self.device, dtype=self.dtype)

    # Broadcast 'other' to self's shape and make it contiguous so flat kernel
    # offsets line up with self's element order.
    try:
        other_bc = torch.broadcast_to(other_t, self.shape).contiguous()
    except Exception as e:
        # Chain the original exception so the broadcast failure is debuggable.
        raise RuntimeError(
            f"logical_xor_: cannot broadcast 'other' to shape {tuple(self.shape)}: {e}"
        ) from e

    # The kernel assumes flat contiguous storage: work on a contiguous copy if
    # self is not contiguous, then copy the result back.
    work_self = self if self.is_contiguous() else self.contiguous()

    n_elements = work_self.numel()
    # Skip the launch entirely for empty tensors (a zero-size grid is a no-op
    # at best and rejected by some launchers at worst).
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        logical_xor___triton_kernel[grid](work_self, other_bc, n_elements, BLOCK_SIZE=1024)

    # If we operated on a contiguous copy, propagate the result back into self.
    if work_self.data_ptr() != self.data_ptr():
        self.copy_(work_self)

    return self