Coverage for src/flag_gems/experimental_ops/log_.py: 0%
34 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def log_(
    x_ptr,  # *Pointer* to input/output vector (in-place).
    n_elements,  # Number of elements.
    BLOCK_SIZE: tl.constexpr,  # Elements processed per program.
):
    # In-place natural logarithm over a contiguous 1-D buffer.
    # Each program instance handles one BLOCK_SIZE-wide slice.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the final partial block; masked lanes are neither read into
    # the result nor written back.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Dtype branch is resolved at trace/compile time (dtypes are constexpr):
    # float64 must stay in float64 — the previous unconditional float32
    # round-trip silently dropped double precision. Narrow types
    # (fp16/bf16) still upcast to float32, which tl.log supports.
    if x.dtype == tl.float64:
        y = tl.log(x)
    else:
        y = tl.log(x.to(tl.float32)).to(x.dtype)
    tl.store(x_ptr + offsets, y, mask=mask)
# Capture the compiled Triton kernel under a distinct name BEFORE the Python
# wrapper below rebinds the module-level name `log_`; the wrapper launches
# the kernel through this alias.
log__triton_kernel = log_
def log_(*args, **kwargs):
    """Apply the natural logarithm to a CUDA tensor in place.

    Mirrors the calling convention of ``torch.log_``: the tensor may be
    passed positionally or as the ``input`` keyword. The tensor itself is
    returned to allow chaining.

    Raises:
        ValueError: no tensor given, tensor not on CUDA, or not contiguous.
        TypeError: argument is not a tensor, or is not floating point.
    """
    if args:
        x = args[0]
    else:
        x = kwargs.get("input", None)

    # Validation ladder: presence -> type -> device -> dtype -> layout.
    if x is None:
        raise ValueError("log_ expects a tensor as the first argument.")
    if not isinstance(x, torch.Tensor):
        raise TypeError("log_ expects a torch.Tensor as input.")
    if not x.is_cuda:
        raise ValueError("Input tensor must be on a CUDA device.")
    if not x.is_floating_point():
        raise TypeError("log_ only supports floating point tensors.")
    if not x.is_contiguous():
        raise ValueError(
            "This log_ Triton implementation requires a contiguous tensor."
        )

    total = x.numel()
    # Nothing to launch for an empty tensor.
    if total == 0:
        return x

    BLOCK_SIZE = 1024
    # One program per BLOCK_SIZE-wide slice, rounded up for the tail.
    num_blocks = triton.cdiv(total, BLOCK_SIZE)
    log__triton_kernel[(num_blocks,)](x, total, BLOCK_SIZE=BLOCK_SIZE)
    return x