Coverage for src/flag_gems/ops/logit.py: 49%
80 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import torch
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
@triton.jit
def logit_kernel(
    x_ptr,
    y_ptr,
    n_elements,
    eps,
    HAS_EPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
):
    # Elementwise logit: y = log(x / (1 - x)), computed in float32.
    # When HAS_EPS is set, x is first clamped to [eps, 1 - eps].
    # Each program instance handles one contiguous BLOCK_SIZE-wide chunk,
    # so both x_ptr and y_ptr must point at contiguous buffers.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the ragged final block
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    x_f32 = x.to(tl.float32)  # promote so log/division run in float32
    if HAS_EPS:
        lo = eps
        hi = 1.0 - eps
        x_f32 = tl.minimum(tl.maximum(x_f32, lo), hi)
    y = tl.log(x_f32 / (1.0 - x_f32))
    # Cast back to the caller-requested storage dtype before the store.
    tl.store(y_ptr + offsets, y.to(OUT_DTYPE), mask=mask)
35def _to_triton_dtype(dtype):
36 if dtype == torch.float32:
37 return tl.float32
38 if dtype == torch.float16:
39 return tl.float16
40 if dtype == torch.bfloat16:
41 return tl.bfloat16
42 return None
def _launch_logit_kernel(src, dst, eps):
    # Launch a 1-D grid of logit_kernel mapping src -> dst. Both tensors
    # must be contiguous with equal numel, and dst must have a dtype that
    # _to_triton_dtype supports, because the kernel indexes flat/linearly.
    n_elements = src.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(src.device):
        logit_kernel[grid](
            src,
            dst,
            n_elements,
            eps if eps is not None else 0.0,
            HAS_EPS=(eps is not None),
            BLOCK_SIZE=BLOCK_SIZE,
            OUT_DTYPE=_to_triton_dtype(dst.dtype),
        )


def _logit_impl(input: torch.Tensor, eps=None, out: torch.Tensor = None):
    """Compute logit(input) = log(input / (1 - input)) elementwise.

    Args:
        input: floating-point tensor.
        eps: optional clamp bound; when given, input is clamped to
            [eps, 1 - eps] before the transform. Must lie in [0.0, 0.5].
        out: optional destination tensor with input's shape and dtype.

    Returns:
        ``out`` when provided, otherwise a new tensor with input's dtype.

    Raises:
        TypeError: non-tensor input, non-floating input, non-tensor out,
            or out.dtype != input.dtype.
        ValueError: eps outside [0.0, 0.5], or out.shape != input.shape.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")
    if not input.is_floating_point():
        raise TypeError("logit expected a floating point tensor as input")
    if eps is not None:
        eps = float(eps)
        if not (0.0 <= eps <= 0.5):
            raise ValueError("eps must be in the range [0.0, 0.5].")

    in_contig = input.contiguous()
    # Dtypes Triton cannot store directly (e.g. float64) are computed in
    # float32 and cast back at the end.
    in_supported = _to_triton_dtype(in_contig.dtype) is not None
    in_kernel = in_contig if in_supported else in_contig.to(torch.float32)

    if out is not None:
        if not isinstance(out, torch.Tensor):
            raise TypeError("out must be a torch.Tensor")
        if out.shape != input.shape:
            raise ValueError("out tensor must have the same shape as input")
        if out.dtype != input.dtype:
            raise TypeError("For logit.out, out.dtype must match input.dtype")
        out_supported = _to_triton_dtype(out.dtype) is not None
        need_copy_back = (not out.is_contiguous()) or (not out_supported)
        if need_copy_back:
            # Write into a contiguous scratch buffer and copy into `out`
            # afterwards. Use torch.empty rather than empty_like: empty_like
            # preserves the source's strides by default, which would hand a
            # non-contiguous buffer to the flat-indexing kernel and scatter
            # results into the wrong positions.
            work_dtype = out.dtype if out_supported else torch.float32
            work_out = torch.empty(out.shape, dtype=work_dtype, device=out.device)
        else:
            work_out = out
        _launch_logit_kernel(in_kernel, work_out, eps)
        if need_copy_back:
            out.copy_(work_out.to(out.dtype))
        return out

    desired_dtype = input.dtype
    if _to_triton_dtype(desired_dtype) is not None:
        # Allocate contiguously (not empty_like, which would mirror a
        # non-contiguous input's strides) — the kernel writes flat offsets.
        result = torch.empty(input.shape, dtype=desired_dtype, device=input.device)
        _launch_logit_kernel(in_kernel, result, eps)
        return result
    # Unsupported output dtype: compute in float32, then cast back.
    work_out = torch.empty(input.shape, dtype=torch.float32, device=input.device)
    _launch_logit_kernel(in_kernel, work_out, eps)
    return work_out.to(desired_dtype)
def logit(input, eps=None):
    """Return the elementwise logit of ``input``.

    ``eps``, when given, clamps the input to [eps, 1 - eps] first.
    Delegates all validation and kernel dispatch to ``_logit_impl``.
    """
    return _logit_impl(input, eps=eps, out=None)
def logit_out(input, eps=None, out=None):
    """Elementwise logit written into ``out``, which is required.

    Raises TypeError when ``out`` is omitted; otherwise returns ``out``.
    """
    if out is not None:
        return _logit_impl(input, eps=eps, out=out)
    raise TypeError("logit_out requires an 'out' tensor.")