Coverage for src/flag_gems/experimental_ops/logaddexp.py: 0%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def logaddexp_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise logaddexp: out[i] = log(exp(x[i]) + exp(y[i])).

    Each program instance handles one BLOCK_SIZE-sized slice of the flat
    arrays. The value is computed in float32 using the overflow-safe form
    max(x, y) + log(1 + exp(-|x - y|)), then cast to the output pointer's
    element type on store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    # Compute in float32 regardless of storage dtype (fp16/bf16 inputs).
    xf32 = x.to(tl.float32)
    yf32 = y.to(tl.float32)

    delta = xf32 - yf32
    adelta = tl.abs(delta)
    m = tl.maximum(xf32, yf32)
    res = m + tl.log(1.0 + tl.exp(-adelta))
    # FIX: when x and y are both +inf or both -inf, delta = inf - inf = NaN,
    # which poisons the stable formula; the exact answer there is x + y
    # (+inf or -inf respectively). Genuine NaN inputs also make delta NaN and
    # still propagate correctly through x + y.
    res = tl.where(delta != delta, xf32 + yf32, res)

    out_ty = out_ptr.dtype.element_ty
    tl.store(out_ptr + offsets, res.to(out_ty), mask=mask)

26 

27 

28def _ensure_cuda_tensor(obj, device, dtype): 

29 if torch.is_tensor(obj): 

30 return obj.to(device=device, dtype=dtype) 

31 else: 

32 return torch.tensor(obj, device=device, dtype=dtype) 

33 

34 

35def _common_float_dtype(x: torch.Tensor, y: torch.Tensor): 

36 dt = torch.result_type(x, y) 

37 if dt not in (torch.float16, torch.bfloat16, torch.float32, torch.float64): 

38 dt = torch.get_default_dtype() 

39 return dt 

40 

41 

42def _launch_logaddexp_kernel(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor): 

43 assert x.is_cuda and y.is_cuda and out.is_cuda, "All tensors must be on CUDA device" 

44 assert ( 

45 x.numel() == y.numel() == out.numel() 

46 ), "Input and output must have the same number of elements" 

47 

48 x_flat = x.contiguous().view(-1) 

49 y_flat = y.contiguous().view(-1) 

50 out_flat = out.contiguous().view(-1) 

51 

52 n_elements = out_flat.numel() 

53 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

54 logaddexp_kernel[grid](x_flat, y_flat, out_flat, n_elements, BLOCK_SIZE=1024) 

55 

56 # If out was non-contiguous, copy results back into original layout 

57 if not out.is_contiguous(): 

58 out.copy_(out_flat.view_as(out)) 

59 

60 

def logaddexp(x, y):
    """Elementwise log(exp(x) + exp(y)) computed on the GPU.

    At least one of *x*, *y* must already be a CUDA tensor; the other may be
    any value ``torch.tensor`` accepts. Inputs are promoted to a common
    floating dtype, broadcast together, and a fresh output tensor is returned.

    Raises:
        ValueError: if neither input is a CUDA tensor.
    """
    # Pick the compute device from whichever input is already a CUDA tensor,
    # preferring x.
    device = None
    for candidate in (x, y):
        if torch.is_tensor(candidate) and candidate.is_cuda:
            device = candidate.device
            break
    if device is None:
        raise ValueError("At least one input must be a CUDA tensor")

    # Decide the common floating dtype; non-tensor inputs are wrapped purely
    # for type promotion.
    dtype = _common_float_dtype(
        x if torch.is_tensor(x) else torch.tensor(x),
        y if torch.is_tensor(y) else torch.tensor(y),
    )

    # Move/cast both inputs, then broadcast to a common shape.
    xb, yb = torch.broadcast_tensors(
        _ensure_cuda_tensor(x, device, dtype),
        _ensure_cuda_tensor(y, device, dtype),
    )

    out = torch.empty_like(xb, dtype=dtype, device=device)
    _launch_logaddexp_kernel(xb, yb, out)
    return out

88 

89 

def logaddexp_out(x, y, out):
    """Elementwise log(exp(x) + exp(y)) written into *out*.

    *out* must be a CUDA tensor with a floating dtype; *x* and *y* are
    moved/cast to out's device and dtype, broadcast together, and the
    broadcast shape must match out's shape exactly.

    Raises:
        ValueError: if *out* is not a CUDA tensor, has a non-floating dtype,
            or its shape differs from the broadcast input shape.
    """
    if not (torch.is_tensor(out) and out.is_cuda):
        raise ValueError("out must be a CUDA tensor")

    # Computation device and dtype come from `out`.
    device, out_dtype = out.device, out.dtype
    if out_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError("out dtype must be a floating point type")

    # Move/cast both inputs and broadcast to a common shape.
    xb, yb = torch.broadcast_tensors(
        _ensure_cuda_tensor(x, device, out_dtype),
        _ensure_cuda_tensor(y, device, out_dtype),
    )

    if tuple(out.shape) != tuple(xb.shape):
        raise ValueError(
            f"out shape {tuple(out.shape)} does not match broadcasted shape {tuple(xb.shape)}"
        )

    _launch_logaddexp_kernel(xb, yb, out)
    return out