Coverage for src/flag_gems/experimental_ops/logit.py: 0%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def logit_kernel(
    x_ptr,
    y_ptr,
    n_elements,
    eps,
    HAS_EPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
):
    """Elementwise logit: y = log(x / (1 - x)), computed in float32.

    x_ptr and y_ptr are treated as flat contiguous buffers of
    n_elements items. When HAS_EPS is set, x is clamped into
    [eps, 1 - eps] before the log. The result is cast to OUT_DTYPE
    on store.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Masked lanes load 0; they are never stored, so the -inf they
    # would produce is discarded.
    val = tl.load(x_ptr + idx, mask=in_bounds, other=0).to(tl.float32)

    if HAS_EPS:
        # Keep the argument of the log strictly inside (0, 1) at the edges.
        val = tl.maximum(val, eps)
        val = tl.minimum(val, 1.0 - eps)

    result = tl.log(val / (1.0 - val))
    tl.store(y_ptr + idx, result.to(OUT_DTYPE), mask=in_bounds)

30 

31 

def _to_triton_dtype(dtype):
    """Map a torch floating dtype to its Triton equivalent.

    Returns None for dtypes the kernel does not handle directly
    (e.g. float64), signalling callers to fall back to float32.
    """
    supported = {
        torch.float32: tl.float32,
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
    }
    return supported.get(dtype)

40 

41 

def _launch_logit_kernel(in_kernel, work_out, eps):
    """Launch logit_kernel over two flat contiguous CUDA buffers.

    `in_kernel` and `work_out` must be contiguous and have the same
    numel; `eps` is None (no clamping) or a float in [0, 0.5].
    """
    n_elements = in_kernel.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    logit_kernel[grid](
        in_kernel,
        work_out,
        n_elements,
        eps if eps is not None else 0.0,
        HAS_EPS=(eps is not None),
        BLOCK_SIZE=BLOCK_SIZE,
        OUT_DTYPE=_to_triton_dtype(work_out.dtype),
    )


def _logit_impl(input: torch.Tensor, eps=None, out: torch.Tensor = None):
    """Compute logit(input) = log(input / (1 - input)) elementwise on CUDA.

    Args:
        input: floating-point CUDA tensor.
        eps: optional clamp bound; input is clamped to [eps, 1 - eps]
            before the log. Must lie in [0.0, 0.5].
        out: optional destination tensor; must match input's shape and
            dtype and live on CUDA. When given, it is filled and returned.

    Returns:
        A new tensor with input's dtype (contiguous) when out is None,
        otherwise `out`.

    Raises:
        TypeError: non-tensor input/out, non-float input, dtype mismatch.
        AssertionError: input or out not on a CUDA device.
        ValueError: eps outside [0.0, 0.5], or out shape mismatch.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")
    if not input.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device for Triton kernel.")
    if not input.is_floating_point():
        raise TypeError("logit expected a floating point tensor as input")
    if eps is not None:
        eps = float(eps)
        if not (0.0 <= eps <= 0.5):
            raise ValueError("eps must be in the range [0.0, 0.5].")

    in_contig = input.contiguous()
    in_supported = _to_triton_dtype(in_contig.dtype) is not None
    # Kernel only handles fp16/bf16/fp32 directly; promote others to fp32.
    in_kernel = in_contig if in_supported else in_contig.to(torch.float32)

    if out is not None:
        if not isinstance(out, torch.Tensor):
            raise TypeError("out must be a torch.Tensor")
        if not out.is_cuda:
            raise AssertionError("Out tensor must be on CUDA device for Triton kernel.")
        if out.shape != input.shape:
            raise ValueError("out tensor must have the same shape as input")
        if out.dtype != input.dtype:
            raise TypeError("For logit.out, out.dtype must match input.dtype")

        # Decide working output (contiguous and with supported dtype)
        out_supported = _to_triton_dtype(out.dtype) is not None
        need_copy_back = (not out.is_contiguous()) or (not out_supported)

        if need_copy_back:
            work_dtype = out.dtype if out_supported else torch.float32
            # Force a contiguous scratch buffer: plain empty_like() preserves
            # `out`'s (possibly permuted) strides, but the kernel writes flat
            # linear offsets, which would scramble the element order relative
            # to the contiguous input.
            work_out = torch.empty_like(
                out, dtype=work_dtype, memory_format=torch.contiguous_format
            )
        else:
            work_out = out

        _launch_logit_kernel(in_kernel, work_out, eps)

        if need_copy_back:
            out.copy_(work_out.to(out.dtype))
        return out

    # out is None -> produce and return a new (contiguous) tensor.
    desired_dtype = input.dtype
    desired_supported = _to_triton_dtype(desired_dtype) is not None
    work_dtype = desired_dtype if desired_supported else torch.float32
    # empty_like on in_kernel (already contiguous) rather than on `input`,
    # whose strides empty_like would otherwise replicate; the explicit
    # memory_format documents the flat-layout requirement of the kernel.
    work_out = torch.empty_like(
        in_kernel, dtype=work_dtype, memory_format=torch.contiguous_format
    )

    _launch_logit_kernel(in_kernel, work_out, eps)

    if desired_supported:
        return work_out
    # Computed in fp32; cast back to the unsupported desired dtype.
    return work_out.to(desired_dtype)

125 

126 

def logit(input, eps=None):
    """Return logit(input) = log(input / (1 - input)) as a new tensor.

    If eps is given, input is clamped to [eps, 1 - eps] before the log.
    """
    return _logit_impl(input, eps=eps)

129 

130 

def logit_out(input, eps=None, out=None):
    """Out-variant of logit: write the result into a caller-provided tensor.

    `out` is mandatory here; passing None raises TypeError.
    """
    if out is not None:
        return _logit_impl(input, eps=eps, out=out)
    raise TypeError("logit_out requires an 'out' tensor.")