# Coverage for src/flag_gems/experimental_ops/logit.py: 0% (81 statements)
# Generated by coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def logit_kernel(
    x_ptr,  # pointer to input elements, read as a flat 1D array
    y_ptr,  # pointer to output elements, written in the same flat order
    n_elements,  # total number of elements to process
    eps,  # clamp bound; only meaningful when HAS_EPS is true
    HAS_EPS: tl.constexpr,  # compile-time flag: clamp input to [eps, 1 - eps]
    BLOCK_SIZE: tl.constexpr,  # elements handled per program instance
    OUT_DTYPE: tl.constexpr,  # Triton dtype the result is cast to on store
):
    """Elementwise logit: y = log(x / (1 - x)), computed in float32."""
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask guards the tail block when n_elements is not a multiple of BLOCK_SIZE.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    # Compute in float32 regardless of the input dtype.
    x_f32 = x.to(tl.float32)

    if HAS_EPS:
        # Clamp the input into [eps, 1 - eps] so the log argument stays finite.
        lo = eps
        hi = 1.0 - eps
        x_f32 = tl.minimum(tl.maximum(x_f32, lo), hi)

    y = tl.log(x_f32 / (1.0 - x_f32))
    # Cast back to the requested output dtype on store.
    tl.store(y_ptr + offsets, y.to(OUT_DTYPE), mask=mask)
32def _to_triton_dtype(dtype):
33 if dtype == torch.float32:
34 return tl.float32
35 if dtype == torch.float16:
36 return tl.float16
37 if dtype == torch.bfloat16:
38 return tl.bfloat16
39 return None
def _launch_logit_kernel(src, dst, eps):
    """Launch logit_kernel over ``src`` into ``dst``.

    Both tensors must be contiguous and hold the same number of elements;
    the launch is a flat 1D grid covering every element.
    """
    n_elements = src.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    logit_kernel[grid](
        src,
        dst,
        n_elements,
        eps if eps is not None else 0.0,
        HAS_EPS=(eps is not None),
        BLOCK_SIZE=BLOCK_SIZE,
        OUT_DTYPE=_to_triton_dtype(dst.dtype),
    )


def _logit_impl(input: torch.Tensor, eps=None, out: torch.Tensor = None):
    """Compute logit(input) = log(input / (1 - input)) elementwise on CUDA.

    Args:
        input: floating-point CUDA tensor.
        eps: optional clamp bound; when given, the input is clamped to
            [eps, 1 - eps] before the log. Must lie in [0.0, 0.5].
        out: optional output tensor; must match ``input``'s shape and dtype.

    Returns:
        ``out`` when provided, otherwise a new tensor with ``input``'s dtype.

    Raises:
        TypeError: non-tensor arguments, non-float input, or dtype mismatch.
        AssertionError: ``input`` or ``out`` not on a CUDA device.
        ValueError: ``eps`` outside [0.0, 0.5] or shape mismatch.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")
    if not input.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device for Triton kernel.")
    if not input.is_floating_point():
        raise TypeError("logit expected a floating point tensor as input")
    if eps is not None:
        eps = float(eps)
        if not (0.0 <= eps <= 0.5):
            raise ValueError("eps must be in the range [0.0, 0.5].")

    in_contig = input.contiguous()
    # The kernel only reads f32/f16/bf16 directly; other float dtypes
    # (e.g. float64) are converted to float32 for the launch.
    in_supported = _to_triton_dtype(in_contig.dtype) is not None
    in_kernel = in_contig if in_supported else in_contig.to(torch.float32)

    if out is not None:
        if not isinstance(out, torch.Tensor):
            raise TypeError("out must be a torch.Tensor")
        if not out.is_cuda:
            raise AssertionError("Out tensor must be on CUDA device for Triton kernel.")
        if out.shape != input.shape:
            raise ValueError("out tensor must have the same shape as input")
        if out.dtype != input.dtype:
            raise TypeError("For logit.out, out.dtype must match input.dtype")

        # Decide working output (contiguous and with supported dtype).
        out_supported = _to_triton_dtype(out.dtype) is not None
        need_copy_back = (not out.is_contiguous()) or (not out_supported)
        if need_copy_back:
            work_dtype = out.dtype if out_supported else torch.float32
            # BUGFIX: torch.empty, not torch.empty_like. empty_like's default
            # preserve_format keeps the strides of a non-contiguous `out`,
            # but the kernel writes assuming a flat contiguous layout, which
            # would scatter results to the wrong logical positions.
            work_out = torch.empty(out.shape, dtype=work_dtype, device=out.device)
        else:
            work_out = out

        _launch_logit_kernel(in_kernel, work_out, eps)

        if need_copy_back:
            out.copy_(work_out.to(out.dtype))
        return out

    # out is None -> produce and return a new tensor with input's dtype.
    desired_dtype = input.dtype
    desired_supported = _to_triton_dtype(desired_dtype) is not None
    work_dtype = desired_dtype if desired_supported else torch.float32
    # BUGFIX: allocate a contiguous result explicitly; empty_like(input)
    # would preserve non-contiguous strides of `input` while the kernel
    # writes in flat contiguous order (same issue as the `out` path above).
    work_out = torch.empty(input.shape, dtype=work_dtype, device=input.device)

    _launch_logit_kernel(in_kernel, work_out, eps)

    # When the desired dtype is unsupported, the math ran in fp32; cast back.
    return work_out if desired_supported else work_out.to(desired_dtype)
def logit(input, eps=None):
    """Return the elementwise logit of ``input`` as a new tensor.

    Thin convenience wrapper over ``_logit_impl`` with no ``out`` tensor.
    """
    return _logit_impl(input, eps=eps)
def logit_out(input, eps=None, out=None):
    """Elementwise logit written into a caller-provided ``out`` tensor.

    Unlike ``logit``, the ``out`` argument is mandatory here.
    """
    if out is not None:
        return _logit_impl(input, eps=eps, out=out)
    raise TypeError("logit_out requires an 'out' tensor.")