Coverage for src/flag_gems/ops/logaddexp.py: 71%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def logaddexp_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise logaddexp: out[i] = log(exp(x[i]) + exp(y[i])).

    Each program instance handles one contiguous block of BLOCK_SIZE elements.
    The result is computed in float32 using the numerically stable form
    max(x, y) + log(1 + exp(-|x - y|)), then cast to the output element type.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    # Compute in float32 for accuracy regardless of the storage dtype.
    xf32 = x.to(tl.float32)
    yf32 = y.to(tl.float32)

    delta = xf32 - yf32
    adelta = tl.abs(delta)
    m = tl.maximum(xf32, yf32)
    res = m + tl.log(1.0 + tl.exp(-adelta))
    # delta is NaN exactly when x and y are both +inf or both -inf (inf - inf),
    # or when either input is NaN. torch.logaddexp returns +inf, -inf, or NaN
    # respectively in those cases — which is precisely x + y — so patch the NaN
    # leaked by the stable formula instead of propagating it.
    res = tl.where(delta != delta, xf32 + yf32, res)

    out_ty = out_ptr.dtype.element_ty
    tl.store(out_ptr + offsets, res.to(out_ty), mask=mask)

33 

34 

35def _ensure_cuda_tensor(obj, device, dtype): 

36 if torch.is_tensor(obj): 

37 return obj.to(device=device, dtype=dtype) 

38 else: 

39 return torch.tensor(obj, device=device, dtype=dtype) 

40 

41 

42def _common_float_dtype(x: torch.Tensor, y: torch.Tensor): 

43 dt = torch.result_type(x, y) 

44 if dt not in (torch.float16, torch.bfloat16, torch.float32, torch.float64): 

45 dt = torch.get_default_dtype() 

46 return dt 

47 

48 

def _launch_logaddexp_kernel(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor):
    """Flatten inputs/output and launch ``logaddexp_kernel`` over them.

    All three tensors must hold the same number of elements. If ``out`` is
    non-contiguous, the kernel writes into a contiguous scratch copy which
    is copied back into ``out``'s original layout afterwards.
    """
    total = out.numel()
    assert (
        x.numel() == total and y.numel() == total
    ), "Input and output must have the same number of elements"

    flat_x = x.contiguous().view(-1)
    flat_y = y.contiguous().view(-1)
    # For a contiguous ``out`` this aliases its storage, so the kernel
    # writes directly into it; otherwise it is a fresh contiguous buffer.
    flat_out = out.contiguous().view(-1)

    n_elements = flat_out.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(x.device):
        logaddexp_kernel[grid](flat_x, flat_y, flat_out, n_elements, BLOCK_SIZE=1024)

    # Non-contiguous ``out``: results landed in the scratch buffer, so move
    # them back into the caller's layout.
    if not out.is_contiguous():
        out.copy_(flat_out.view_as(out))

66 

67 

def logaddexp(x, y):
    """Compute log(exp(x) + exp(y)) elementwise on the GPU.

    Accepts tensors or Python scalars; at least one argument must already
    be a CUDA tensor (it determines the device). Inputs are promoted to a
    common floating dtype, broadcast together, and a fresh output tensor
    is returned.

    Raises:
        ValueError: if neither input is a CUDA tensor.
    """
    logger.debug("GEMS LOGADDEXP")

    # Pick the device from the first CUDA-tensor argument (x wins over y).
    device = None
    for candidate in (x, y):
        if torch.is_tensor(candidate) and candidate.is_cuda:
            device = candidate.device
            break
    if device is None:
        raise ValueError("At least one input must be a CUDA tensor")

    # Probe tensors (CPU is fine here) just to run dtype promotion.
    x_probe = x if torch.is_tensor(x) else torch.tensor(x)
    y_probe = y if torch.is_tensor(y) else torch.tensor(y)
    dtype = _common_float_dtype(x_probe, y_probe)

    # Move both operands to the target device/dtype and broadcast.
    xb, yb = torch.broadcast_tensors(
        _ensure_cuda_tensor(x, device, dtype),
        _ensure_cuda_tensor(y, device, dtype),
    )

    out = torch.empty_like(xb, dtype=dtype, device=device)
    _launch_logaddexp_kernel(xb, yb, out)
    return out

96 

97 

def logaddexp_out(x, y, out):
    """Compute log(exp(x) + exp(y)) elementwise, writing into ``out``.

    The device and dtype of the computation are taken from ``out``; inputs
    are converted and broadcast to match. ``out`` is returned.

    Raises:
        ValueError: if ``out`` is not a tensor, has a non-floating dtype,
            or its shape does not match the broadcasted input shape.
    """
    logger.debug("GEMS LOGADDEXP_OUT")

    # ``out`` dictates everything: it must be a float tensor.
    if not torch.is_tensor(out):
        raise ValueError("out must be a tensor")
    device, out_dtype = out.device, out.dtype
    if out_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise ValueError("out dtype must be a floating point type")

    # Bring both inputs onto out's device/dtype and broadcast them.
    xb, yb = torch.broadcast_tensors(
        _ensure_cuda_tensor(x, device, out_dtype),
        _ensure_cuda_tensor(y, device, out_dtype),
    )

    # The broadcasted shape must line up exactly with the provided buffer.
    if tuple(out.shape) != tuple(xb.shape):
        raise ValueError(
            f"out shape {tuple(out.shape)} does not match broadcasted shape {tuple(xb.shape)}"
        )

    _launch_logaddexp_kernel(xb, yb, out)
    return out