Coverage for src/flag_gems/ops/logit_.py: 52%
63 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
13@triton.jit
14def logit_kernel(
15 x_ptr,
16 n_elements,
17 eps,
18 has_eps: tl.constexpr,
19 COMPUTE_FP32: tl.constexpr,
20 COMPUTE_FP64: tl.constexpr,
21 BLOCK_SIZE: tl.constexpr,
22):
23 pid = tl.program_id(axis=0)
24 block_start = pid * BLOCK_SIZE
25 offsets = block_start + tl.arange(0, BLOCK_SIZE)
26 mask = offsets < n_elements
28 x = tl.load(x_ptr + offsets, mask=mask)
30 if COMPUTE_FP32:
31 xc = x.to(tl.float32)
32 if has_eps:
33 xc = tl.maximum(xc, eps)
34 xc = tl.minimum(xc, 1.0 - eps)
35 y = tl.log(xc / (1.0 - xc))
36 out = y.to(x.dtype)
37 elif COMPUTE_FP64:
38 xc = x
39 if has_eps:
40 xc = tl.maximum(xc, eps)
41 xc = tl.minimum(xc, 1.0 - eps)
42 out = tl.log(xc / (1.0 - xc))
43 else:
44 xc = x
45 if has_eps:
46 xc = tl.maximum(xc, eps)
47 xc = tl.minimum(xc, 1.0 - eps)
48 out = tl.log(xc / (1.0 - xc))
50 tl.store(x_ptr + offsets, out, mask=mask)
53def logit_(*args, **kwargs):
54 logger.debug("GEMS LOGIT_")
55 if len(args) == 0:
56 raise TypeError("logit_ expected at least 1 argument (got 0)")
57 x = args[0]
58 eps = None
59 if len(args) > 1:
60 eps = args[1]
61 if "eps" in kwargs:
62 eps = kwargs["eps"]
64 if not isinstance(x, torch.Tensor):
65 raise TypeError("logit_ expects a torch.Tensor as the first argument")
66 if not x.is_floating_point():
67 raise TypeError("logit_ expects a floating point tensor")
69 has_eps = eps is not None
70 eps_value = float(eps) if has_eps else 0.0
72 needs_copy_back = not x.is_contiguous()
73 buf = x if not needs_copy_back else x.contiguous()
75 n_elements = buf.numel()
76 if n_elements == 0:
77 return x
79 dtype = buf.dtype
80 compute_in_fp32 = dtype in (torch.float16, torch.bfloat16)
81 compute_in_fp64 = dtype == torch.float64
83 BLOCK_SIZE = 1024
84 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
86 with torch_device_fn.device(x.device):
87 logit_kernel[grid](
88 buf,
89 n_elements,
90 eps_value,
91 has_eps=has_eps,
92 COMPUTE_FP32=compute_in_fp32,
93 COMPUTE_FP64=compute_in_fp64,
94 BLOCK_SIZE=BLOCK_SIZE,
95 )
97 if needs_copy_back:
98 x.copy_(buf)
100 return x