Coverage report for src/flag_gems/ops/logit.py: 49% of 80 statements covered (coverage.py v7.6.9, created at 2026-03-25 02:48 +0800).

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7 

8 

@triton.jit
def logit_kernel(
    x_ptr,
    y_ptr,
    n_elements,
    eps,
    HAS_EPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
):
    """Elementwise logit: y = log(x / (1 - x)), computed in float32.

    When HAS_EPS is set, x is first clamped to [eps, 1 - eps]; the
    result is cast to OUT_DTYPE before being stored.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    # Masked lanes load 0 but are never stored, so the value is irrelevant.
    val = tl.load(x_ptr + offs, mask=in_bounds, other=0).to(tl.float32)

    if HAS_EPS:
        # Clamp into [eps, 1 - eps] so the log argument stays strictly
        # positive and finite for in-range eps.
        val = tl.maximum(val, eps)
        val = tl.minimum(val, 1.0 - eps)

    result = tl.log(val / (1.0 - val))
    tl.store(y_ptr + offs, result.to(OUT_DTYPE), mask=in_bounds)

33 

34 

def _to_triton_dtype(dtype):
    """Map a torch floating dtype to its Triton counterpart.

    Returns None for dtypes the kernel does not support directly
    (callers fall back to a float32 working buffer in that case).
    """
    return {
        torch.float32: tl.float32,
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
    }.get(dtype)

43 

44 

def _launch_logit_kernel(in_kernel, work_out, eps, device):
    # Launch logit_kernel so that work_out = logit(in_kernel) elementwise.
    # Both tensors must be contiguous and hold the same number of elements.
    n_elements = in_kernel.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(device):
        logit_kernel[grid](
            in_kernel,
            work_out,
            n_elements,
            eps if eps is not None else 0.0,
            HAS_EPS=(eps is not None),
            BLOCK_SIZE=BLOCK_SIZE,
            OUT_DTYPE=_to_triton_dtype(work_out.dtype),
        )


def _logit_impl(input: torch.Tensor, eps=None, out: torch.Tensor = None):
    """Compute logit(input) = log(input / (1 - input)) with a Triton kernel.

    Args:
        input: floating-point tensor.
        eps: optional clamp bound; when given, input is clamped to
            [eps, 1 - eps] before the logit is taken. Must be in [0.0, 0.5].
        out: optional output tensor; must match input's shape and dtype.

    Returns:
        `out` if it was provided, otherwise a new tensor with input's dtype.

    Raises:
        TypeError: input/out is not a tensor, input is not floating point,
            or out.dtype differs from input.dtype.
        ValueError: eps lies outside [0.0, 0.5], or out has the wrong shape.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")
    if not input.is_floating_point():
        raise TypeError("logit expected a floating point tensor as input")
    if eps is not None:
        eps = float(eps)
        if not (0.0 <= eps <= 0.5):
            raise ValueError("eps must be in the range [0.0, 0.5].")

    # The kernel indexes linearly, so it is always fed a contiguous tensor;
    # dtypes without a Triton mapping are computed in float32 and cast back.
    in_contig = input.contiguous()
    in_supported = _to_triton_dtype(in_contig.dtype) is not None
    in_kernel = in_contig if in_supported else in_contig.to(torch.float32)

    if out is not None:
        if not isinstance(out, torch.Tensor):
            raise TypeError("out must be a torch.Tensor")
        if out.shape != input.shape:
            raise ValueError("out tensor must have the same shape as input")
        if out.dtype != input.dtype:
            raise TypeError("For logit.out, out.dtype must match input.dtype")
        out_supported = _to_triton_dtype(out.dtype) is not None
        # A non-contiguous or unsupported-dtype `out` needs a contiguous
        # staging buffer plus a copy back.
        need_copy_back = (not out.is_contiguous()) or (not out_supported)

        if need_copy_back:
            # Allocate the staging buffer with torch.empty (row-major
            # contiguous) rather than empty_like(out): empty_like can
            # preserve a non-row-major memory format (e.g. channels_last),
            # which the linearly-writing kernel would fill in the wrong
            # logical order.
            work_dtype = out.dtype if out_supported else torch.float32
            work_out = torch.empty(
                input.shape, dtype=work_dtype, device=input.device
            )
        else:
            work_out = out

        _launch_logit_kernel(in_kernel, work_out, eps, input.device)

        if need_copy_back:
            out.copy_(work_out.to(out.dtype))
        return out

    desired_dtype = input.dtype
    desired_supported = _to_triton_dtype(desired_dtype) is not None
    # Allocate the result from the contiguous kernel input rather than from
    # `input` itself: empty_like(input) could preserve a non-row-major
    # memory format that the linear-indexing kernel would fill wrongly.
    work_out = torch.empty_like(
        in_kernel,
        dtype=desired_dtype if desired_supported else torch.float32,
    )

    _launch_logit_kernel(in_kernel, work_out, eps, input.device)

    return work_out if desired_supported else work_out.to(desired_dtype)

123 

124 

def logit(input, eps=None):
    """Return logit(input); input may optionally be clamped via *eps*.

    Thin functional wrapper over _logit_impl with no output tensor.
    """
    return _logit_impl(input, eps)

127 

128 

def logit_out(input, eps=None, out=None):
    """Compute logit(input) into the caller-provided tensor *out*.

    Returns *out*; raises TypeError when *out* is omitted.
    """
    if out is not None:
        return _logit_impl(input, eps=eps, out=out)
    raise TypeError("logit_out requires an 'out' tensor.")