Coverage for src/flag_gems/experimental_ops/exp2.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def exp2_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance covers one BLOCK_SIZE-wide slice of the flat buffer.
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    active = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=active)
    # 2**v == e**(v * ln 2); ln 2 written out to full double precision.
    LN2 = 0.693147180559945309417232121458176568
    result = tl.exp(vals * LN2)
    tl.store(out_ptr + idx, result, mask=active)

15 

16 

def exp2(x: torch.Tensor) -> torch.Tensor:
    """Compute ``2**x`` elementwise with a Triton kernel.

    Args:
        x: CUDA tensor of dtype float16, bfloat16, or float32.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``x`` is not on a CUDA device.
        TypeError: if ``x.dtype`` is not a supported floating-point dtype.
    """
    if not x.is_cuda:
        raise ValueError("exp2: input tensor must be on CUDA device")
    supported_dtypes = (torch.float16, torch.bfloat16, torch.float32)
    if x.dtype not in supported_dtypes:
        raise TypeError(
            f"exp2: unsupported dtype {x.dtype}. Supported: {supported_dtypes}"
        )
    # The kernel indexes a flat contiguous buffer, so densify the input first.
    x_contig = x.contiguous()
    out = torch.empty_like(x_contig)
    n_elements = out.numel()
    # Skip the launch for empty tensors: a zero-sized grid is pointless and is
    # rejected by some Triton versions.
    if n_elements == 0:
        return out
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    exp2_kernel[grid](x_contig, out, n_elements, BLOCK_SIZE=1024)
    return out

31 

32 

def exp2_out(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Compute ``2**x`` elementwise into a caller-provided ``out`` tensor.

    Args:
        x: CUDA tensor of dtype float16, bfloat16, or float32.
        out: CUDA tensor with the same shape and dtype as ``x``.

    Returns:
        ``out``, filled with the elementwise result.

    Raises:
        ValueError: if either tensor is not on CUDA, or shapes differ.
        TypeError: if dtypes differ or the dtype is unsupported.
    """
    if not x.is_cuda or not out.is_cuda:
        raise ValueError("exp2_out: both input and out tensors must be on CUDA device")
    if x.shape != out.shape:
        raise ValueError("exp2_out: input and out must have the same shape")
    if x.dtype != out.dtype:
        raise TypeError("exp2_out: input and out must have the same dtype")
    supported_dtypes = (torch.float16, torch.bfloat16, torch.float32)
    if x.dtype not in supported_dtypes:
        raise TypeError(
            f"exp2_out: unsupported dtype {x.dtype}. Supported: {supported_dtypes}"
        )
    # The kernel indexes flat contiguous buffers; densify both tensors.
    x_contig = x.contiguous()
    out_contig = out.contiguous()
    n_elements = out_contig.numel()
    # Skip the launch for empty tensors: a zero-sized grid is pointless and is
    # rejected by some Triton versions.
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        exp2_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=1024)
    # If ``out`` was non-contiguous, ``contiguous()`` made a copy; write the
    # results back into the caller's original storage.
    if out_contig.data_ptr() != out.data_ptr():
        out.copy_(out_contig)
    return out