Coverage for src/flag_gems/experimental_ops/logaddexp2.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def logaddexp2_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise logaddexp2 over flat buffers: out[i] = log2(2^x[i] + 2^y[i]).

    Computed as m + log2(1 + 2^(-|x - y|)) with m = max(x, y) for numerical
    stability; inputs are upcast to float32 and the store casts back to the
    dtype of out_ptr.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements

    # Load inputs and upcast to fp32 for numerics
    x = tl.load(x_ptr + offs, mask=mask, other=0).to(tl.float32)
    y = tl.load(y_ptr + offs, mask=mask, other=0).to(tl.float32)

    ln2 = 0.6931471805599453
    inv_ln2 = 1.4426950408889634

    m = tl.maximum(x, y)
    # Bug fix: x - y is NaN when x == y == +/-inf (inf - inf). Force the
    # difference to 0 whenever x == y so the result collapses to m
    # (m + log2(2) ... no: d = 0 -> t = 1 -> m + 1, which equals
    # logaddexp2(x, x) = x + 1 — the exact answer, and +/-inf + 1 = +/-inf).
    d = tl.where(x == y, 0.0, tl.abs(x - y))
    t = tl.exp(-d * ln2)  # 2^(-|x-y|) = exp(-(abs(x-y)) * ln(2))
    res = m + tl.log(1.0 + t) * inv_ln2  # log2(1 + t) = ln(1+t) / ln(2)

    # Store; Triton will cast to the dtype of out_ptr as needed
    tl.store(out_ptr + offs, res, mask=mask)

28 

29 

30def _broadcast_and_check(x, y): 

31 # Convert scalars to tensors 

32 if not isinstance(x, torch.Tensor): 

33 x = torch.as_tensor(x) 

34 if not isinstance(y, torch.Tensor): 

35 y = torch.as_tensor(y) 

36 # Broadcast 

37 bx, by = torch.broadcast_tensors(x, y) 

38 return bx, by 

39 

40 

41def _choose_out_dtype(x, y, out=None): 

42 if out is not None: 

43 return out.dtype 

44 # Prefer highest precision floating dtype present; else default dtype 

45 float_priority = [torch.float64, torch.float32, torch.bfloat16, torch.float16] 

46 for dt in float_priority: 

47 if x.dtype == dt or y.dtype == dt: 

48 return dt 

49 # If none are floating, use default dtype 

50 return torch.get_default_dtype() 

51 

52 

def _launch_kernel(xc, yc, outc):
    """Run logaddexp2_kernel over flat 1D buffers xc, yc into outc."""
    total = outc.numel()
    # Launching a zero-sized grid is pointless (and invalid): bail early.
    if total == 0:
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    logaddexp2_kernel[grid](xc, yc, outc, total, BLOCK_SIZE=1024)

59 

60 

def logaddexp2(x, y):
    """Compute log2(2^x + 2^y) elementwise with broadcasting.

    Accepts tensors or Python scalars. Uses the Triton kernel on CUDA;
    otherwise falls back to the ATen implementation.

    Returns:
        A new tensor of the promoted floating dtype (see _choose_out_dtype).
    """
    bx, by = _broadcast_and_check(x, y)

    # Fall back to ATen when the Triton path can't be used: non-CUDA or
    # mismatched devices, complex dtypes, or float64 inputs — the kernel
    # computes in float32, which would silently lose float64 precision.
    if (
        bx.device.type != "cuda"
        or by.device.type != "cuda"
        or bx.device != by.device
        or bx.is_complex()
        or by.is_complex()
        or bx.dtype == torch.float64
        or by.dtype == torch.float64
    ):
        return torch.ops.aten.logaddexp2(bx, by)

    out_dtype = _choose_out_dtype(bx, by, out=None)
    out = torch.empty(bx.shape, device=bx.device, dtype=out_dtype)

    # Kernel works on flat, contiguous buffers.
    xc = bx.contiguous().view(-1)
    yc = by.contiguous().view(-1)
    outc = out.view(-1)  # freshly-allocated `out` is already contiguous

    _launch_kernel(xc, yc, outc)
    return out

84 

85 

def logaddexp2_out(x, y, out):
    """Out variant of logaddexp2: write log2(2^x + 2^y) into `out` and return it.

    Falls back to the ATen implementation for non-CUDA / mixed-device or
    complex inputs. `out` must already have the broadcast shape; its dtype may
    differ from the inputs (the kernel's store casts into it).
    """
    if out is None:
        raise ValueError("out tensor must be provided for logaddexp2_out")

    bx, by = _broadcast_and_check(x, y)

    # Triton path requires every operand on the same CUDA device and no
    # complex dtypes anywhere; otherwise defer to PyTorch.
    all_same_cuda = (
        out.device.type == "cuda"
        and bx.device.type == "cuda"
        and by.device.type == "cuda"
        and bx.device == by.device == out.device
    )
    has_complex = bx.is_complex() or by.is_complex() or out.is_complex()
    if not all_same_cuda or has_complex:
        # Use PyTorch implementation for unsupported cases
        return torch.ops.aten.logaddexp2.out(bx, by, out=out)

    if out.shape != bx.shape:
        raise ValueError(
            f"out tensor has shape {out.shape}, expected {bx.shape} from broadcast"
        )
    # dtype differences are deliberately allowed; results land in out's dtype.

    # Flatten inputs into contiguous 1D buffers for the kernel.
    xc = bx.contiguous().view(-1)
    yc = by.contiguous().view(-1)

    if not out.is_contiguous():
        # Kernel needs a flat contiguous target: stage into a temporary
        # contiguous buffer, then copy back into the strided `out`.
        tmp = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_kernel(xc, yc, tmp.view(-1))
        out.copy_(tmp)
        return out

    _launch_kernel(xc, yc, out.view(-1))
    return out