Coverage for src/flag_gems/experimental_ops/logical_xor_.py: 0%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def logical_xor_(a_ptr, b_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place elementwise logical XOR: a[i] = (a[i] != 0) ^ (b[i] != 0).

    The result overwrites the contents of ``a_ptr``; ``b_ptr`` is read-only.
    Both buffers are assumed flat and contiguous with ``n_elements`` entries.
    """
    pid = tl.program_id(axis=0)
    # Contiguous chunk of BLOCK_SIZE flat indices owned by this program.
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    lhs = tl.load(a_ptr + idx, mask=in_bounds)
    rhs = tl.load(b_ptr + idx, mask=in_bounds)

    # Coerce each operand to a boolean before XOR-ing, mirroring
    # torch.logical_xor semantics for arbitrary input dtypes.
    result = (lhs != 0) ^ (rhs != 0)
    tl.store(a_ptr + idx, result, mask=in_bounds)

21 

22 

# Preserve a reference to the Triton kernel object: the Python wrapper defined
# below reuses the name ``logical_xor_`` and would otherwise shadow the kernel.
logical_xor___triton_kernel = logical_xor_

25 

26 

def logical_xor_(*args, **kwargs):
    """In-place logical XOR backed by a Triton kernel (``self.logical_xor_(other)``).

    Accepts ``(self, other)`` positionally, or via the keywords
    ``input``/``self`` and ``other``. ``self`` must be a CUDA tensor of dtype
    ``torch.bool``; ``other`` may be a tensor broadcastable to ``self``'s
    shape, or a Python scalar.

    Returns:
        The mutated ``self`` tensor (standard in-place-op convention).

    Raises:
        TypeError: if ``self`` is not a tensor, or ``other`` was not supplied.
        RuntimeError: if ``self`` is not a CUDA bool tensor, or ``other``
            cannot be broadcast to ``self``'s shape.
    """
    # Parse inputs: expect (self, other).
    if len(args) >= 2:
        self, other = args[0], args[1]
    else:
        self = kwargs.get("input", kwargs.get("self", None))
        other = kwargs.get("other", None)

    if not isinstance(self, torch.Tensor):
        raise TypeError("logical_xor_: first argument must be a torch.Tensor")
    # Validate 'other' explicitly: torch.tensor(None, ...) below would raise
    # an opaque error instead of a clear "missing argument" message.
    if other is None:
        raise TypeError("logical_xor_: missing required argument 'other'")
    if self.dtype is not torch.bool:
        raise RuntimeError(
            "logical_xor_: in-place logical operations require self to have dtype torch.bool"
        )
    if not self.is_cuda:
        raise RuntimeError("logical_xor_: tensor must be on CUDA device")

    # Prepare 'other' as a tensor on the same device; scalars become 0-d
    # tensors with self's (bool) dtype.
    if isinstance(other, torch.Tensor):
        other_t = other.to(device=self.device)
    else:
        other_t = torch.tensor(other, device=self.device, dtype=self.dtype)

    # Broadcast 'other' to self's shape and make it contiguous so the kernel
    # can use flat linear indexing.
    try:
        other_bc = torch.broadcast_to(other_t, self.shape).contiguous()
    except Exception as e:
        # Chain the original exception so the underlying broadcast failure
        # stays visible in the traceback.
        raise RuntimeError(
            f"logical_xor_: cannot broadcast 'other' to shape {tuple(self.shape)}: {e}"
        ) from e

    # The kernel assumes flat contiguous storage: work on a contiguous copy
    # when self is non-contiguous, then copy the result back.
    work_self = self if self.is_contiguous() else self.contiguous()

    n_elements = work_self.numel()
    if n_elements == 0:
        # Nothing to do; skip launching an empty grid.
        return self

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    logical_xor___triton_kernel[grid](work_self, other_bc, n_elements, BLOCK_SIZE=1024)

    # .contiguous() returned a copy iff self was non-contiguous; in that case
    # the kernel wrote into the copy, so propagate the result back into self.
    if work_self.data_ptr() != self.data_ptr():
        self.copy_(work_self)

    return self