Coverage for src/flag_gems/experimental_ops/xlogy.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def xlogy_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise ``xlogy`` over flat buffers holding ``n_elements`` values.

    Each program instance processes one ``BLOCK_SIZE``-wide slice of the
    buffers addressed by ``x_ptr``, ``y_ptr`` and ``out_ptr``.  For every
    in-range element the stored result is:

    * ``NaN`` when ``y`` is NaN (this takes precedence over ``x == 0``),
    * ``0`` when ``x`` is zero,
    * ``x * log(y)`` otherwise.

    The arithmetic is performed in float32 after converting both operands.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block: lanes past the end neither load nor store.
    mask = offsets < n_elements

    # Masked-out lanes get x=0 / y=1 so the math below stays finite for them.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=1)

    x_f32 = x.to(tl.float32)
    y_f32 = y.to(tl.float32)

    # result = where(isnan(y), nan, where(x == 0, 0, x * log(y)))
    res = tl.where(x_f32 == 0.0, 0.0, x_f32 * tl.log(y_f32))
    # NaN is the only value that compares unequal to itself; forwarding y
    # propagates the NaN even when x is zero.
    res = tl.where(y_f32 != y_f32, y_f32, res)

    tl.store(out_ptr + offsets, res, mask=mask)

23 

24 

25def _ensure_tensor_on_device(obj, device, dtype): 

26 if isinstance(obj, torch.Tensor): 

27 return obj.to(device=device, dtype=dtype) 

28 else: 

29 return torch.as_tensor(obj, device=device, dtype=dtype) 

30 

31 

32def _prepare_tensors(self, other, out=None): 

33 # Determine device 

34 if isinstance(self, torch.Tensor): 

35 device = self.device 

36 elif isinstance(other, torch.Tensor): 

37 device = other.device 

38 else: 

39 raise ValueError("At least one of the inputs must be a Tensor.") 

40 

41 if device.type != "cuda": 

42 raise ValueError("Triton kernels require CUDA tensors.") 

43 

44 # Type promotion following PyTorch semantics 

45 if isinstance(self, torch.Tensor) and isinstance(other, torch.Tensor): 

46 result_dtype = torch.result_type(self, other) 

47 elif isinstance(self, torch.Tensor): 

48 other_tmp = torch.as_tensor(other) 

49 result_dtype = torch.result_type(self, other_tmp) 

50 else: 

51 self_tmp = torch.as_tensor(self) 

52 result_dtype = torch.result_type(self_tmp, other) 

53 

54 t_self = _ensure_tensor_on_device(self, device, result_dtype) 

55 t_other = _ensure_tensor_on_device(other, device, result_dtype) 

56 

57 # Broadcast to a common shape 

58 b_self, b_other = torch.broadcast_tensors(t_self, t_other) 

59 

60 # Prepare output 

61 if out is None: 

62 out_tensor = torch.empty(b_self.shape, device=device, dtype=result_dtype) 

63 return b_self.contiguous(), b_other.contiguous(), out_tensor, out_tensor 

64 else: 

65 if out.device != device: 

66 raise ValueError("Output tensor must be on the same device as inputs.") 

67 # Out dtype/shape should be able to hold result 

68 expected_shape = b_self.shape 

69 if out.shape != expected_shape: 

70 raise ValueError( 

71 f"Output tensor has shape {out.shape}, expected {expected_shape}." 

72 ) 

73 if out.dtype != result_dtype: 

74 raise ValueError( 

75 f"Output tensor has dtype {out.dtype}, expected {result_dtype}." 

76 ) 

77 # If out is contiguous, write directly; otherwise use a temporary 

78 if out.is_contiguous(): 

79 return b_self.contiguous(), b_other.contiguous(), out, out 

80 else: 

81 tmp = torch.empty(expected_shape, device=device, dtype=result_dtype) 

82 return b_self.contiguous(), b_other.contiguous(), tmp, out 

83 

84 

def _launch_xlogy(self, other, out=None):
    """Run ``xlogy_kernel`` on ``self`` and ``other`` and return the result.

    Args:
        self: Left operand (tensor or scalar).
        other: Right operand (tensor or scalar).
        out: Optional destination tensor; when given, it is the tensor
            returned.

    Returns:
        The tensor holding ``xlogy(self, other)``: ``out`` if supplied,
        otherwise a newly allocated tensor.

    Raises:
        ValueError: Propagated from ``_prepare_tensors`` on invalid inputs.
    """
    x, y, dst, final_out = _prepare_tensors(self, other, out)
    n_elements = dst.numel()

    # Nothing to compute for empty tensors; skip the launch entirely.
    if n_elements > 0:

        def grid(meta):
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        # Launch on the device that owns the tensors, which may differ from
        # the current CUDA device in multi-GPU processes.
        with torch.cuda.device(x.device):
            xlogy_kernel[grid](x, y, dst, n_elements, BLOCK_SIZE=1024)

    # A non-contiguous ``out`` was computed into a temporary; copy it back.
    if final_out is not dst:
        final_out.copy_(dst)
    return final_out

99 

100 

101# Wrappers corresponding to ATen operator interfaces 

def xlogy_Tensor(self: torch.Tensor, other: torch.Tensor):
    """Tensor-tensor ``xlogy``; the result is returned in a new tensor."""
    return _launch_xlogy(self, other)

104 

105 

def xlogy_Scalar_Other(self: torch.Tensor, other):
    """``xlogy`` with a tensor ``self`` and a scalar ``other``; returns a new tensor."""
    return _launch_xlogy(self, other)

108 

109 

def xlogy_Scalar_Self(self, other: torch.Tensor):
    """``xlogy`` with a scalar ``self`` and a tensor ``other``; returns a new tensor."""
    return _launch_xlogy(self, other)

112 

113 

def xlogy_OutTensor(self: torch.Tensor, other: torch.Tensor, out: torch.Tensor):
    """Tensor-tensor ``xlogy`` that writes the result into ``out`` and returns it."""
    return _launch_xlogy(self, other, out)

116 

117 

def xlogy_OutScalar_Self(self, other: torch.Tensor, out: torch.Tensor):
    """``xlogy`` with a scalar ``self`` and tensor ``other``, written into ``out``."""
    return _launch_xlogy(self, other, out)

120 

121 

def xlogy_OutScalar_Other(self: torch.Tensor, other, out: torch.Tensor):
    """``xlogy`` with a tensor ``self`` and scalar ``other``, written into ``out``."""
    return _launch_xlogy(self, other, out)