Coverage for src/flag_gems/experimental_ops/exp2.py: 0%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def exp2_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise base-2 exponential: out[i] = 2 ** x[i].

    Each program instance handles one contiguous BLOCK_SIZE-wide slice of the
    flattened input; `mask` guards out-of-range lanes in the final partial
    block.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Use the dedicated exp2 primitive rather than exp(x * ln2): it avoids
    # the rounding error introduced by a truncated ln(2) constant and lowers
    # to the hardware ex2 instruction on NVIDIA GPUs.
    y = tl.exp2(x)
    tl.store(out_ptr + offsets, y, mask=mask)
def exp2(x: torch.Tensor) -> torch.Tensor:
    """Return a new tensor holding the elementwise base-2 exponential of ``x``.

    Args:
        x: CUDA tensor of dtype float16, bfloat16, or float32. Non-contiguous
            inputs are copied to contiguous storage before the kernel launch.

    Returns:
        A contiguous tensor with the same shape and dtype as ``x`` containing
        ``2 ** x``.

    Raises:
        ValueError: if ``x`` is not on a CUDA device.
        TypeError: if ``x`` has an unsupported dtype.
    """
    if not x.is_cuda:
        raise ValueError("exp2: input tensor must be on CUDA device")
    supported_dtypes = (torch.float16, torch.bfloat16, torch.float32)
    if x.dtype not in supported_dtypes:
        raise TypeError(
            f"exp2: unsupported dtype {x.dtype}. Supported: {supported_dtypes}"
        )
    x_contig = x.contiguous()
    out = torch.empty_like(x_contig)
    n_elements = out.numel()
    # A zero-sized grid is an invalid kernel launch; short-circuit empty
    # tensors instead of launching.
    if n_elements == 0:
        return out

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    exp2_kernel[grid](x_contig, out, n_elements, BLOCK_SIZE=1024)
    return out
def exp2_out(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Compute the elementwise base-2 exponential of ``x`` into ``out``.

    Args:
        x: CUDA tensor of dtype float16, bfloat16, or float32.
        out: CUDA tensor with the same shape and dtype as ``x``; receives the
            result. If ``out`` is non-contiguous, the kernel writes to a
            contiguous copy which is then copied back into ``out``.

    Returns:
        ``out``, for call chaining.

    Raises:
        ValueError: if either tensor is not on a CUDA device, or the shapes
            differ.
        TypeError: if the dtypes differ or are unsupported.
    """
    if not x.is_cuda or not out.is_cuda:
        raise ValueError("exp2_out: both input and out tensors must be on CUDA device")
    if x.shape != out.shape:
        raise ValueError("exp2_out: input and out must have the same shape")
    if x.dtype != out.dtype:
        raise TypeError("exp2_out: input and out must have the same dtype")
    supported_dtypes = (torch.float16, torch.bfloat16, torch.float32)
    if x.dtype not in supported_dtypes:
        raise TypeError(
            f"exp2_out: unsupported dtype {x.dtype}. Supported: {supported_dtypes}"
        )
    x_contig = x.contiguous()
    out_contig = out.contiguous()
    n_elements = out_contig.numel()
    # A zero-sized grid is an invalid kernel launch; empty tensors need no
    # work at all.
    if n_elements == 0:
        return out

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    exp2_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=1024)
    # `.contiguous()` returns a copy when `out` is non-contiguous; in that
    # case the kernel wrote to the copy, so propagate the result back.
    if out_contig.data_ptr() != out.data_ptr():
        out.copy_(out_contig)
    return out