Coverage for src/flag_gems/experimental_ops/log2_.py: 0%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def log2_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place base-2 logarithm over a flat buffer of ``n_elements`` values.

    Each program instance processes one BLOCK_SIZE-wide tile; ``mask`` guards
    the ragged tail so out-of-range lanes neither load nor store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    val = tl.load(x_ptr + offsets, mask=mask)
    # Compute in float32 for accuracy when the buffer holds fp16/bf16 values.
    x = val.to(tl.float32)
    # tl.log2 is the direct base-2 log; it replaces the previous
    # tl.log(x) * (1/ln 2) form, dropping the extra constant and multiply.
    y = tl.log2(x)

    out = y.to(val.dtype)
    tl.store(x_ptr + offsets, out, mask=mask)


# Keep a reference to the Triton kernel before redefining the name for the
# Python wrapper below.
_log2__kernel = log2_
def log2_(*args, **kwargs):
    """In-place base-2 logarithm of a tensor, mirroring ``torch.Tensor.log2_``.

    The tensor may be passed positionally or as ``input=``. Contiguous CUDA
    tensors of float16/bfloat16/float32 are handled by the Triton kernel;
    everything else falls back to PyTorch's own ``log2_``. The input tensor is
    mutated and returned.

    Raises:
        ValueError: if no tensor argument is supplied.
        TypeError: if the argument is not a ``torch.Tensor``.
    """
    x = args[0] if args else kwargs.get("input")
    if x is None:
        raise ValueError("log2_ expects a tensor as the first argument.")
    if not isinstance(x, torch.Tensor):
        raise TypeError("log2_ expects a torch.Tensor as input.")

    # Nothing to do for empty tensors.
    if x.numel() == 0:
        return x

    # Fallback for non-CUDA tensors or dtypes the kernel does not support.
    if (not x.is_cuda) or (
        x.dtype not in (torch.float16, torch.bfloat16, torch.float32)
    ):
        x.log2_()
        return x

    # The kernel indexes a flat buffer, so it needs contiguous storage;
    # work on a contiguous copy and write back afterwards if one was made.
    # (The redundant second numel check that used to sit here was dead code:
    # emptiness is already handled above, and contiguous() preserves numel.)
    x_contig = x if x.is_contiguous() else x.contiguous()
    n_elements = x_contig.numel()

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _log2__kernel[grid](x_contig, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if x_contig is not x:
        x.copy_(x_contig)

    return x