Coverage for src/flag_gems/experimental_ops/special_i0e.py: 0%
45 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _special_i0e_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise i0e(x) = I0(x) * exp(-|x|), the exponentially scaled
    modified Bessel function of the first kind, order zero.

    Each program instance handles one BLOCK_SIZE-wide contiguous slice of the
    flat input/output buffers; out-of-range lanes are masked off.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # Compute in fp32 for accuracy/stability
    # NOTE(review): float64 inputs are also routed through fp32 here, so
    # double-precision callers get single-precision accuracy — confirm intended.
    xf = x.to(tl.float32)
    # i0e is even in x, so only |x| matters.
    ax = tl.abs(xf)

    # Small region: x <= 3.75
    t_small = ax / 3.75
    t2 = t_small * t_small
    # Polynomial approximation for I0 in small region (Numerical Recipes)
    # (Horner form in t2 = (|x|/3.75)^2.)
    p = 1.0 + t2 * (
        3.5156229
        + t2
        * (
            3.0899424
            + t2 * (1.2067492 + t2 * (0.2659732 + t2 * (0.0360768 + t2 * 0.0045813)))
        )
    )
    # The small-region polynomial approximates I0 itself, so apply the
    # exp(-|x|) scaling explicitly to get i0e.
    small = p * tl.exp(-ax)

    # Large region: x > 3.75, use asymptotic expansion to avoid exp overflow
    # i0e(x) = I0(x)*exp(-|x|) ≈ (1/sqrt(|x|)) * poly(3.75/|x|)
    # At ax == 0 this branch yields inf/nan, but tl.where below selects the
    # small branch there, so the result is unaffected.
    t = 3.75 / ax
    q = 0.39894228 + t * (
        0.01328592
        + t
        * (
            0.00225319
            + t
            * (
                -0.00157565
                + t
                * (
                    0.00916281
                    + t
                    * (
                        -0.02057706
                        + t * (0.02635537 + t * (-0.01647633 + t * 0.00392377))
                    )
                )
            )
        )
    )
    large = q / tl.sqrt(ax)

    # Select the region per element; boundary 3.75 matches the fit domains.
    is_large = ax > 3.75
    y = tl.where(is_large, large, small)

    # Cast back to input dtype for storage
    y = y.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)
66def _run_special_i0e_kernel(x: torch.Tensor, out: torch.Tensor):
67 assert x.is_cuda and out.is_cuda, "Tensors must be CUDA tensors"
68 assert x.dtype in (
69 torch.float16,
70 torch.bfloat16,
71 torch.float32,
72 torch.float64,
73 ), "Unsupported dtype"
74 assert out.dtype == x.dtype, "Output dtype must match input dtype"
76 x_c = x.contiguous()
77 out_c = out.contiguous()
79 n_elements = out_c.numel()
80 if n_elements == 0:
81 return out
83 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
84 _special_i0e_kernel[grid](x_c, out_c, n_elements, BLOCK_SIZE=1024)
86 if out_c.data_ptr() != out.data_ptr():
87 out.copy_(out_c)
88 return out
def special_i0e(x: torch.Tensor):
    """
    ATen wrapper: special_i0e(Tensor self) -> Tensor

    Allocates a result tensor matching ``x`` (shape, dtype, device) and
    delegates the elementwise computation to the Triton kernel launcher.
    """
    result = torch.empty_like(x)
    return _run_special_i0e_kernel(x, result)
def special_i0e_out(x: torch.Tensor, out: torch.Tensor):
    """
    ATen wrapper: special_i0e.out(Tensor self, Tensor out) -> Tensor

    The input is broadcast (via a zero-copy ``expand``) to ``out``'s shape
    when the shapes differ, then the kernel launcher fills ``out`` in place.
    """
    src = x if x.shape == out.shape else x.expand(out.shape)
    _run_special_i0e_kernel(src, out)
    return out