Coverage for src/flag_gems/experimental_ops/log2_.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def log2_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Compute the base-2 logarithm of a flat buffer in place.

    Each program instance owns one BLOCK_SIZE-wide slice of the buffer:
    it loads the slice, computes log2 in float32, casts back to the
    source dtype, and stores the result over the original values.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    raw = tl.load(x_ptr + idx, mask=in_bounds)
    as_f32 = raw.to(tl.float32)

    # log2(x) = ln(x) / ln(2); multiply by the precomputed 1/ln(2).
    recip_ln2 = tl.full((), 1.4426950408889634, tl.float32)
    result = tl.log(as_f32) * recip_ln2

    tl.store(x_ptr + idx, result.to(raw.dtype), mask=in_bounds)

21 

22 

# Keep a reference to the Triton kernel before redefining the name for the Python wrapper.
# The wrapper below shadows `log2_`; callers invoke the wrapper, which dispatches
# to this saved kernel handle on the CUDA fast path.
_log2__kernel = log2_

25 

26 

def log2_(*args, **kwargs):
    """In-place base-2 logarithm, mirroring ``torch.Tensor.log2_``.

    Accepts the tensor either positionally or via the ``input`` keyword.
    CUDA tensors with dtype float16/bfloat16/float32 are processed by the
    Triton kernel; everything else (CPU tensors, other dtypes) falls back
    to PyTorch's native in-place ``log2_``.

    Returns:
        The input tensor, mutated in place (in-place op convention).

    Raises:
        ValueError: if no tensor argument is supplied.
        TypeError: if the argument is not a ``torch.Tensor``.
    """
    x = args[0] if args else kwargs.get("input", None)
    if x is None:
        raise ValueError("log2_ expects a tensor as the first argument.")
    if not isinstance(x, torch.Tensor):
        raise TypeError("log2_ expects a torch.Tensor as input.")

    # Nothing to do for empty tensors.
    if x.numel() == 0:
        return x

    # Fallback for non-CUDA tensors or unsupported dtypes.
    if (not x.is_cuda) or (
        x.dtype not in (torch.float16, torch.bfloat16, torch.float32)
    ):
        # Use PyTorch's implementation as a fallback.
        x.log2_()
        return x

    # The kernel assumes a flat, densely packed buffer; work on a
    # contiguous copy and write the result back if one was made.
    x_contig = x if x.is_contiguous() else x.contiguous()

    # x.numel() > 0 was established above, so no redundant empty check here.
    n_elements = x_contig.numel()

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _log2__kernel[grid](x_contig, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if x_contig is not x:
        x.copy_(x_contig)

    return x