Coverage for src/flag_gems/experimental_ops/special_xlog1py.py: 0%

90 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _xlog1py_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise xlog1py over flat float32 buffers: out = x * log1p(y).

    Matches ``torch.special.xlog1py`` semantics: the result is 0 where
    ``x == 0`` and ``y`` is not NaN. The previous naive ``x * log(1 + y)``
    produced NaN for ``x == 0, y <= -1`` (0 * -inf) instead of 0.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    raw = x * tl.log(1.0 + y)
    # x == 0 forces the result to 0, except that a NaN y must still
    # propagate; (y == y) is the "y is not NaN" test.
    out = tl.where((x == 0.0) & (y == y), 0.0, raw)
    tl.store(out_ptr + offsets, out, mask=mask)

16 

17 

18def _ensure_cuda_tensor(t): 

19 if not isinstance(t, torch.Tensor): 

20 raise TypeError("Expected a torch.Tensor") 

21 if t.device.type != "cuda": 

22 raise ValueError("Tensors must be on CUDA device") 

23 return t 

24 

25 

def _prepare_inputs(x, y):
    """Validate, broadcast, and promote a pair of CUDA tensors for the kernel.

    Returns ``(x32, y32, dtype_out)`` where ``x32``/``y32`` are contiguous
    float32 copies of the broadcast inputs and ``dtype_out`` is the torch
    promotion dtype the caller should cast the float32 result back to.
    """
    checked_x = _ensure_cuda_tensor(x)
    checked_y = _ensure_cuda_tensor(y)
    xb, yb = torch.broadcast_tensors(checked_x, checked_y)
    dtype_out = torch.result_type(xb, yb)
    # The kernel computes in float32; contiguity gives it flat, dense buffers.
    return (
        xb.to(torch.float32).contiguous(),
        yb.to(torch.float32).contiguous(),
        dtype_out,
    )

34 

35 

def _launch_xlog1py(x_fp32, y_fp32, out_fp32):
    """Launch the Triton xlog1py kernel over flat float32 buffers in-place.

    All three tensors must be contiguous, same-shape float32 CUDA tensors;
    the result is written into *out_fp32*.
    """
    total = out_fp32.numel()
    block = 1024
    # BLOCK_SIZE is fixed, so the grid can be computed up front.
    grid = (triton.cdiv(total, block),)
    _xlog1py_kernel[grid](x_fp32, y_fp32, out_fp32, total, BLOCK_SIZE=block)

40 

41 

def special_xlog1py(x, y):
    """Compute ``torch.special.xlog1py(x, y)`` for two CUDA tensors.

    Inputs are broadcast against each other; the output dtype follows torch
    type promotion of the broadcast inputs.
    """
    x32, y32, dtype_out = _prepare_inputs(x, y)
    result = torch.empty_like(x32)
    _launch_xlog1py(x32, y32, result)
    if dtype_out != torch.float32:
        # Kernel computed in float32; cast back to the promoted dtype.
        result = result.to(dtype_out)
    return result

50 

51 

def special_xlog1py_other_scalar(x, other):
    """Compute ``torch.special.xlog1py(x, other)`` where *other* is a scalar.

    Bug fix: the scalar was previously materialized with ``dtype=x.dtype``,
    which truncated e.g. ``other=1.5`` to ``1`` when ``x`` is an integer
    tensor. We now compute the promoted dtype with scalar-aware promotion
    (``torch.result_type(tensor, number)``) first and materialize the scalar
    in that dtype, matching torch's scalar-promotion semantics.
    """
    x = _ensure_cuda_tensor(x)
    # Scalar-aware promotion: a Python float promotes an int tensor to the
    # default float dtype rather than to float64 as a 0-d tensor would.
    dtype_out = torch.result_type(x, other)
    other_tensor = torch.as_tensor(other, device=x.device, dtype=dtype_out)
    xb_fp32, yb_fp32, _ = _prepare_inputs(x, other_tensor)
    out_fp32 = torch.empty_like(xb_fp32)
    _launch_xlog1py(xb_fp32, yb_fp32, out_fp32)
    if dtype_out == torch.float32:
        return out_fp32
    return out_fp32.to(dtype_out)

62 

63 

def special_xlog1py_self_scalar(self, other):
    """Compute ``torch.special.xlog1py(self, other)`` where *self* is a scalar.

    Bug fix: the scalar was previously materialized with ``dtype=other.dtype``,
    which truncated e.g. ``self=1.5`` to ``1`` when ``other`` is an integer
    tensor. We now compute the promoted dtype with scalar-aware promotion
    first and materialize the scalar in that dtype, matching torch semantics.
    """
    other = _ensure_cuda_tensor(other)
    # Scalar-aware promotion (argument order does not affect the result).
    dtype_out = torch.result_type(other, self)
    self_tensor = torch.as_tensor(self, device=other.device, dtype=dtype_out)
    xb_fp32, yb_fp32, _ = _prepare_inputs(self_tensor, other)
    out_fp32 = torch.empty_like(xb_fp32)
    _launch_xlog1py(xb_fp32, yb_fp32, out_fp32)
    if dtype_out == torch.float32:
        return out_fp32
    return out_fp32.to(dtype_out)

74 

75 

def special_xlog1py_out(x, y, out):
    """Compute xlog1py(x, y) into the preallocated CUDA tensor *out*.

    *out* must match the broadcast shape of *x* and *y*; the float32 kernel
    result is cast to ``out.dtype`` when copied in. Returns *out*.

    Raises:
        ValueError: if *out* has the wrong shape or any tensor is not on CUDA.
    """
    out = _ensure_cuda_tensor(out)
    x32, y32, _ = _prepare_inputs(
        _ensure_cuda_tensor(x), _ensure_cuda_tensor(y)
    )
    # x32/y32 already share the broadcast shape; this spells it out for the
    # shape check against `out`.
    expected_shape = torch.broadcast_shapes(x32.shape, y32.shape)
    if out.shape != expected_shape:
        raise ValueError(f"Out tensor has shape {out.shape}, expected {expected_shape}")
    scratch = torch.empty(expected_shape, device=out.device, dtype=torch.float32)
    _launch_xlog1py(x32, y32, scratch)
    out.copy_(scratch.to(out.dtype))
    return out

89 

90 

def special_xlog1py_self_scalar_out(self, other, out):
    """Compute xlog1py(self, other) into *out*, where *self* is a scalar.

    Bug fix: the scalar was previously materialized with ``dtype=other.dtype``,
    truncating e.g. ``self=1.5`` against an integer tensor before the kernel
    ran. We now use scalar-aware promotion (``torch.result_type``) to pick the
    scalar's dtype. The result is cast to ``out.dtype`` on copy; returns *out*.

    Raises:
        ValueError: if *out* has the wrong shape or any tensor is not on CUDA.
    """
    out = _ensure_cuda_tensor(out)
    other = _ensure_cuda_tensor(other)
    # Scalar-aware promotion preserves the scalar's value in the computation.
    promo = torch.result_type(other, self)
    self_tensor = torch.as_tensor(self, device=other.device, dtype=promo)
    xb_fp32, yb_fp32, _ = _prepare_inputs(self_tensor, other)
    expected_shape = torch.broadcast_shapes(xb_fp32.shape, yb_fp32.shape)
    if out.shape != expected_shape:
        raise ValueError(f"Out tensor has shape {out.shape}, expected {expected_shape}")
    out_fp32 = torch.empty(expected_shape, device=out.device, dtype=torch.float32)
    _launch_xlog1py(xb_fp32, yb_fp32, out_fp32)
    out.copy_(out_fp32.to(out.dtype))
    return out

103 

104 

def special_xlog1py_other_scalar_out(x, other, out):
    """Compute xlog1py(x, other) into *out*, where *other* is a scalar.

    Bug fix: the scalar was previously materialized with ``dtype=x.dtype``,
    truncating e.g. ``other=1.5`` against an integer tensor before the kernel
    ran. We now use scalar-aware promotion (``torch.result_type``) to pick the
    scalar's dtype. The result is cast to ``out.dtype`` on copy; returns *out*.

    Raises:
        ValueError: if *out* has the wrong shape or any tensor is not on CUDA.
    """
    out = _ensure_cuda_tensor(out)
    x = _ensure_cuda_tensor(x)
    # Scalar-aware promotion preserves the scalar's value in the computation.
    promo = torch.result_type(x, other)
    other_tensor = torch.as_tensor(other, device=x.device, dtype=promo)
    xb_fp32, yb_fp32, _ = _prepare_inputs(x, other_tensor)
    expected_shape = torch.broadcast_shapes(xb_fp32.shape, yb_fp32.shape)
    if out.shape != expected_shape:
        raise ValueError(f"Out tensor has shape {out.shape}, expected {expected_shape}")
    out_fp32 = torch.empty(expected_shape, device=out.device, dtype=torch.float32)
    _launch_xlog1py(xb_fp32, yb_fp32, out_fp32)
    out.copy_(out_fp32.to(out.dtype))
    return out