Coverage for src/flag_gems/ops/special_i0e.py: 53%
45 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import torch
3import triton
4import triton.language as tl
@triton.jit
def _special_i0e_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise exponentially-scaled modified Bessel I0: out = I0(x) * exp(-|x|).

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    tensors. Arithmetic is done in float32 and cast back to the input
    dtype on store. Coefficients follow the classic Abramowitz & Stegun /
    Numerical Recipes rational approximations, split at |x| = 3.75.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    x = tl.load(x_ptr + offs, mask=in_bounds)
    ax = tl.abs(x.to(tl.float32))

    # --- Small region, |x| <= 3.75: polynomial in (|x|/3.75)^2, then
    # multiply by exp(-|x|) to get the scaled value.
    s = ax / 3.75
    s2 = s * s
    p = 0.0045813
    p = 0.0360768 + s2 * p
    p = 0.2659732 + s2 * p
    p = 1.2067492 + s2 * p
    p = 3.0899424 + s2 * p
    p = 3.5156229 + s2 * p
    i0e_small = (1.0 + s2 * p) * tl.exp(-ax)

    # --- Large region, |x| > 3.75: asymptotic series in 3.75/|x|.
    # The exp factors cancel analytically, so no overflow-prone exp(|x|)
    # is ever computed: i0e(x) ~= poly(3.75/|x|) / sqrt(|x|).
    r = 3.75 / ax
    q = 0.00392377
    q = -0.01647633 + r * q
    q = 0.02635537 + r * q
    q = -0.02057706 + r * q
    q = 0.00916281 + r * q
    q = -0.00157565 + r * q
    q = 0.00225319 + r * q
    q = 0.01328592 + r * q
    q = 0.39894228 + r * q
    i0e_large = q / tl.sqrt(ax)

    result = tl.where(ax > 3.75, i0e_large, i0e_small)
    # Store in the caller's dtype; masked lanes are never written.
    tl.store(out_ptr + offs, result.to(x.dtype), mask=in_bounds)
def _run_special_i0e_kernel(x: torch.Tensor, out: torch.Tensor):
    """Validate arguments and launch the i0e kernel, writing into ``out``.

    Both tensors must be CUDA tensors of the same floating dtype. Returns
    ``out`` for chaining. Handles non-contiguous inputs/outputs by working
    on contiguous copies and syncing results back when needed.
    """
    assert x.is_cuda and out.is_cuda, "Tensors must be CUDA tensors"
    assert x.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ), "Unsupported dtype"
    assert out.dtype == x.dtype, "Output dtype must match input dtype"

    src = x.contiguous()
    dst = out.contiguous()

    numel = dst.numel()
    if numel == 0:
        # Nothing to compute; avoid launching an empty grid.
        return out

    def grid(meta):
        return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    _special_i0e_kernel[grid](src, dst, numel, BLOCK_SIZE=1024)

    # ``contiguous()`` may have materialized a copy of ``out``; if so,
    # propagate the kernel's results back into the caller's tensor.
    if dst.data_ptr() != out.data_ptr():
        out.copy_(dst)
    return out
def special_i0e(x: torch.Tensor):
    """
    ATen wrapper: special_i0e(Tensor self) -> Tensor

    Allocates a fresh output tensor matching ``x`` and fills it with
    the exponentially-scaled Bessel I0 of each element.
    """
    result = torch.empty_like(x)
    return _run_special_i0e_kernel(x, result)
def special_i0e_out(x: torch.Tensor, out: torch.Tensor):
    """
    ATen wrapper: special_i0e.out(Tensor self, Tensor out) -> Tensor

    Writes the element-wise result into ``out`` and returns it. When the
    shapes differ, ``x`` is broadcast (via ``expand``) to ``out``'s shape
    before launching the kernel.
    """
    src = x if x.shape == out.shape else x.expand(out.shape)
    _run_special_i0e_kernel(src, out)
    return out