Coverage for src/flag_gems/experimental_ops/logit_.py: 0%
61 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def logit_(
    x_ptr,
    n_elements,
    eps,
    has_eps: tl.constexpr,
    COMPUTE_FP32: tl.constexpr,
    COMPUTE_FP64: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """In-place logit kernel: x <- log(x / (1 - x)).

    When ``has_eps`` is set, x is first clamped to ``[eps, 1 - eps]``.
    ``COMPUTE_FP32`` selects the fp16/bf16 path (promote to fp32 for the
    computation, cast back on store). ``COMPUTE_FP64`` is retained for
    interface compatibility with the wrapper; the fp64 path is identical to
    the plain fp32 path (the tensor is already in its compute precision), so
    the two formerly duplicated branches are merged below.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    if COMPUTE_FP32:
        # fp16/bf16 input: promote for numerical accuracy, cast back at the end.
        xc = x.to(tl.float32)
        if has_eps:
            xc = tl.maximum(xc, eps)
            xc = tl.minimum(xc, 1.0 - eps)
        out = tl.log(xc / (1.0 - xc)).to(x.dtype)
    else:
        # fp32 / fp64 input: compute directly in the tensor's own precision.
        xc = x
        if has_eps:
            xc = tl.maximum(xc, eps)
            xc = tl.minimum(xc, 1.0 - eps)
        out = tl.log(xc / (1.0 - xc))

    tl.store(x_ptr + offsets, out, mask=mask)
# The Python wrapper defined below reuses the name `logit_`, which would
# shadow the Triton kernel above; save a reference to the kernel first so
# the wrapper can still launch it.
logit___kernel = logit_
def logit_(*args, **kwargs):
    """In-place logit, mirroring ``torch.logit_(input, eps=None)``.

    Computes ``x = log(x / (1 - x))`` element-wise on ``input`` itself,
    optionally clamping ``input`` to ``[eps, 1 - eps]`` first.

    Args:
        input (torch.Tensor): CUDA floating-point tensor, modified in place.
        eps (float, optional): clamp bound; ``None`` disables clamping.

    Returns:
        torch.Tensor: ``input`` (the same object), for call chaining.

    Raises:
        TypeError: wrong argument count/types, ``eps`` supplied both
            positionally and by keyword, or unexpected keyword arguments.
        ValueError: ``input`` is not a CUDA tensor.
    """
    # --- argument parsing, mimicking a real `(input, eps=None)` signature ---
    if len(args) == 0:
        raise TypeError("logit_ expected at least 1 argument (got 0)")
    if len(args) > 2:
        raise TypeError(
            f"logit_ expected at most 2 positional arguments (got {len(args)})"
        )
    x = args[0]
    eps = args[1] if len(args) > 1 else None
    if "eps" in kwargs:
        # A real Python signature would reject eps given both ways; do the same
        # instead of silently letting the keyword win.
        if len(args) > 1:
            raise TypeError("logit_ got multiple values for argument 'eps'")
        eps = kwargs.pop("eps")
    if kwargs:
        # Don't silently ignore unknown keywords — that hides caller bugs.
        raise TypeError(
            f"logit_ got unexpected keyword argument(s): {sorted(kwargs)}"
        )

    if not isinstance(x, torch.Tensor):
        raise TypeError("logit_ expects a torch.Tensor as the first argument")
    if not x.is_cuda:
        raise ValueError("logit_ Triton implementation requires a CUDA tensor")
    if not x.is_floating_point():
        raise TypeError("logit_ expects a floating point tensor")

    has_eps = eps is not None
    eps_value = float(eps) if has_eps else 0.0

    # The kernel assumes a flat contiguous buffer; process non-contiguous
    # tensors in a contiguous copy and copy back to preserve in-place
    # semantics.
    needs_copy_back = not x.is_contiguous()
    buf = x.contiguous() if needs_copy_back else x

    n_elements = buf.numel()
    if n_elements == 0:
        return x

    dtype = buf.dtype
    # fp16/bf16 are promoted to fp32 inside the kernel for accuracy.
    compute_in_fp32 = dtype in (torch.float16, torch.bfloat16)
    compute_in_fp64 = dtype == torch.float64

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    logit___kernel[grid](
        buf,
        n_elements,
        eps_value,
        has_eps=has_eps,
        COMPUTE_FP32=compute_in_fp32,
        COMPUTE_FP64=compute_in_fp64,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    if needs_copy_back:
        x.copy_(buf)

    return x