Coverage for src/flag_gems/experimental_ops/i0.py: 0%
57 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def i0_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise I0 (modified Bessel function of the first kind, order 0).

    Evaluates the classic two-regime polynomial approximation
    (Abramowitz & Stegun 9.8.1 / 9.8.2):
      * |x| <= 3.75: polynomial in (x/3.75)^2
      * |x| >  3.75: exp(|x|) / sqrt(|x|) times a polynomial in 3.75/|x|
    All math is performed in fp32; the final store casts to out_ptr's
    element type.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked-out lanes load 0 so they stay finite in the small-x branch.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    x_f32 = x.to(tl.float32)
    ax = tl.abs(x_f32)

    # Small region: |x| <= 3.75
    t = x_f32 / 3.75
    y = t * t  # only even powers of x appear, so the sign of x drops out
    p_small = 1.0 + y * (
        3.5156229
        + y
        * (
            3.0899424
            + y * (1.2067492 + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))
        )
    )

    # Large region: |x| > 3.75
    yb = 3.75 / ax
    p_big = 0.39894228 + yb * (
        0.01328592
        + yb
        * (
            0.00225319
            + yb
            * (
                -0.00157565
                + yb
                * (
                    0.00916281
                    + yb
                    * (
                        -0.02057706
                        + yb * (0.02635537 + yb * (-0.01647633 + yb * 0.00392377))
                    )
                )
            )
        )
    )
    # res_big is evaluated for every lane; at ax == 0 the divisions yield
    # inf/nan, but tl.where below selects p_small there, so the big-branch
    # value is discarded wherever ax <= 3.75.
    res_big = tl.exp(ax) * p_big / tl.sqrt(ax)

    use_small = ax <= 3.75
    res = tl.where(use_small, p_small, res_big)

    # Store result; Triton casts the fp32 result to out_ptr's dtype as needed.
    tl.store(out_ptr + offsets, res, mask=mask)
def _launch_i0(out: torch.Tensor, x: torch.Tensor):
    """Prepare contiguous buffers for ``i0_kernel`` and launch it.

    Fills ``out`` in place with I0(x) and returns it. Non-floating inputs
    are promoted to the default float dtype, then cast to ``out``'s dtype
    (the kernel itself computes in fp32 and casts on store).
    """
    assert x.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert (
        out.numel() == x.numel()
    ), "Input and output must have the same number of elements"
    assert out.device == x.device, "Input and output must be on the same device"

    src = x
    # Promote integer/bool inputs so the kernel sees floating-point data.
    if not src.is_floating_point():
        src = src.to(torch.get_default_dtype())
    # Align the input dtype with the requested output dtype.
    if src.dtype != out.dtype:
        src = src.to(out.dtype)
    src = src.contiguous()

    # The kernel writes flat contiguous memory; stage into a scratch buffer
    # when `out` is not contiguous and copy back afterwards.
    needs_copy_back = not out.is_contiguous()
    dst = out.contiguous() if needs_copy_back else out

    n_elements = dst.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    i0_kernel[grid](src, dst, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if needs_copy_back:
        out.copy_(dst)
    return out
def i0(x: torch.Tensor):
    """Return I0(x), the zeroth-order modified Bessel function, elementwise.

    Args:
        x: CUDA tensor; integer/bool inputs produce a result in the default
           floating dtype, floating inputs keep their dtype.

    Returns:
        A new tensor of the same shape as ``x`` holding I0(x).

    Raises:
        ValueError: if ``x`` is not on a CUDA device.
    """
    if not x.is_cuda:
        raise ValueError("i0: input tensor must be on CUDA device")
    # Result dtype follows PyTorch's floating type behavior; use input dtype
    # if floating, otherwise the default float dtype.
    out_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # Allocate the output directly with the target dtype. The previous code
    # called x.to(dtype=out_dtype) here, materializing a throwaway cast copy
    # of x just to supply shape/layout to empty_like.
    out = torch.empty_like(x, dtype=out_dtype, device=x.device)
    _launch_i0(out, x)
    return out
def i0_out(x: torch.Tensor, out: torch.Tensor):
    """Compute I0(x) elementwise into a caller-provided tensor.

    Args:
        x: CUDA input tensor.
        out: CUDA floating-point tensor with the same number of elements
            as ``x``; receives the result in place.

    Returns:
        ``out``.

    Raises:
        ValueError: if either tensor is not CUDA, or the element counts differ.
        TypeError: if ``out`` is not a floating-point tensor.
    """
    # Guard clauses, checked in a fixed order: device, dtype, then size.
    if not (x.is_cuda and out.is_cuda):
        raise ValueError("i0_out: input and output tensors must be on CUDA device")
    if not out.is_floating_point():
        raise TypeError("i0_out: output tensor must be a floating point type")
    if x.numel() != out.numel():
        raise ValueError(
            "i0_out: input and output must have the same number of elements"
        )
    return _launch_i0(out, x)