Coverage for src/flag_gems/experimental_ops/log_.py: 0%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

import torch
import triton
import triton.language as tl

4 

5 

@triton.jit
def log_(
    x_ptr,  # *Pointer* to input/output vector (updated in place).
    n_elements,  # Total number of elements in the vector.
    BLOCK_SIZE: tl.constexpr,  # Elements processed per program instance.
):
    """Elementwise in-place natural logarithm: x[i] = log(x[i]).

    Launched over a 1D grid; each program instance handles one
    BLOCK_SIZE-sized contiguous chunk, with a mask guarding the tail.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # Compute in float32 so low-precision inputs (fp16/bf16) keep accuracy,
    # then cast back to the original dtype before the in-place store.
    x_f32 = x.to(tl.float32)
    y_f32 = tl.log(x_f32)
    y = y_f32.to(x.dtype)
    tl.store(x_ptr + offsets, y, mask=mask)

22 

23 

# Keep a handle to the Triton kernel before defining the Python wrapper with the same name.
log__triton_kernel = log_

26 

27 

def log_(*args, **kwargs):
    """In-place natural logarithm of a CUDA tensor via a Triton kernel.

    Mirrors ``torch.log_``: the tensor may be passed positionally or as the
    ``input`` keyword. The tensor is mutated in place and returned.

    Raises:
        ValueError: if no tensor is supplied, the tensor is not on a CUDA
            device, or it is not contiguous.
        TypeError: if the argument is not a floating-point ``torch.Tensor``.
    """
    x = args[0] if len(args) > 0 else kwargs.get("input", None)
    if x is None:
        raise ValueError("log_ expects a tensor as the first argument.")
    if not isinstance(x, torch.Tensor):
        raise TypeError("log_ expects a torch.Tensor as input.")
    if not x.is_cuda:
        raise ValueError("Input tensor must be on a CUDA device.")
    if not x.is_floating_point():
        raise TypeError("log_ only supports floating point tensors.")
    if not x.is_contiguous():
        raise ValueError(
            "This log_ Triton implementation requires a contiguous tensor."
        )

    n_elements = x.numel()
    # Nothing to do for an empty tensor; launching a zero-size grid is invalid.
    if n_elements == 0:
        return x

    BLOCK_SIZE = 1024
    # 1D launch grid: one program per BLOCK_SIZE chunk of the flattened tensor.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    log__triton_kernel[grid](x, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x