Coverage for src/flag_gems/experimental_ops/special_i1.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def special_i1_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise modified Bessel function of the first kind, order 1: I1(x).

    Uses the classic Abramowitz & Stegun polynomial approximations
    (sections 9.8.3 / 9.8.4): a 7-term polynomial in (x/3.75)^2 for
    |x| <= 3.75, and an asymptotic expansion exp(|x|)/sqrt(|x|) * poly(3.75/|x|)
    for |x| > 3.75. All arithmetic is done in float32 and cast back to the
    input dtype on store.

    Args:
        x_ptr: pointer to the input values.
        out_ptr: pointer to the output buffer (same element count as input).
        n_elements: total number of elements to process.
        BLOCK_SIZE: compile-time number of elements handled per program.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Compute in float32 regardless of input dtype for accuracy.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    x_f32 = x.to(tl.float32)
    ax = tl.abs(x_f32)

    # Small region: |x| <= 3.75
    y = x_f32 / 3.75
    y2 = y * y
    # Horner polynomial for small |x|.
    # A&S 9.8.3 approximates I1(x)/x, hence the final multiply by x below.
    # The odd symmetry I1(-x) = -I1(x) falls out automatically from that
    # multiplication, since the polynomial itself is even in x.
    p = 0.00032411
    p = 0.00301532 + y2 * p
    p = 0.02658733 + y2 * p
    p = 0.15084934 + y2 * p
    p = 0.51498869 + y2 * p
    p = 0.87890594 + y2 * p
    p = 0.5 + y2 * p
    ans_small = x_f32 * p

    # Large region: |x| > 3.75
    # Use asymptotic expansion: I1(x) ~ exp(|x|)/sqrt(|x|) * poly(3.75/|x|)
    # Coefficients from Cephes
    # The 1e-20 clamp only guards against division by zero at ax == 0;
    # that lane's large-branch value is discarded by the tl.where below.
    t = 3.75 / tl.maximum(ax, 1e-20)
    q = -0.00420059
    q = 0.01787654 + t * q
    q = -0.02895312 + t * q
    q = 0.02282967 + t * q
    q = -0.01031555 + t * q
    q = 0.00163801 + t * q
    q = -0.00362018 + t * q
    q = -0.03988024 + t * q
    q = 0.39894228 + t * q
    pref = tl.exp(ax) / tl.sqrt(tl.maximum(ax, 1e-20))
    ans_large = pref * q
    # I1 is odd: the expansion is in |x|, so restore the sign explicitly here.
    ans_large = tl.where(x_f32 < 0, -ans_large, ans_large)

    # Select per-lane between the two regimes.
    is_small = ax <= 3.75
    ans = tl.where(is_small, ans_small, ans_large)

    # Cast back to input dtype and store
    tl.store(out_ptr + offsets, ans.to(x.dtype), mask=mask)
def _launch_special_i1(x: torch.Tensor, out: torch.Tensor):
    """Launch the I1 Triton kernel elementwise, reading *x* and writing *out*.

    Both tensors must be CUDA tensors of identical dtype and element count;
    they are treated as flat buffers, so they should be contiguous.
    """
    assert x.is_cuda and out.is_cuda, "Tensors must be CUDA tensors"
    assert (
        x.numel() == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"

    total = x.numel()
    if total == 0:
        # Nothing to compute; a zero-sized grid launch would be invalid.
        return

    block = 1024

    def grid(meta):
        # One program per BLOCK_SIZE-sized slice of the flat buffer.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    special_i1_kernel[grid](x, out, total, BLOCK_SIZE=block)
def special_i1(self: torch.Tensor):
    """Return a new tensor holding I1 applied elementwise to *self*.

    The computation runs on a contiguous copy of the input; the result is
    always a contiguous tensor with the same shape and dtype as *self*.
    """
    src = self
    flat = src.contiguous()
    result = torch.empty_like(flat)
    _launch_special_i1(flat, result)
    # A contiguous strided input maps one-to-one onto the kernel output;
    # otherwise reshape the (contiguous) result back to the input's shape.
    if src.layout == torch.strided and src.is_contiguous():
        return result
    return result.view_as(src)
def special_i1_out(self: torch.Tensor, out: torch.Tensor):
    """Compute I1 elementwise into the caller-provided *out* tensor.

    Raises:
        TypeError: if *out* does not match the input's dtype or device.

    Returns:
        The *out* tensor, filled with the results.
    """
    src = self
    # Validate the destination before doing any work.
    if out.dtype != src.dtype:
        raise TypeError("out dtype must match input dtype")
    if out.device != src.device:
        raise TypeError("out device must match input device")

    src_c = src.contiguous()
    dst = out.contiguous()
    _launch_special_i1(src_c, dst)
    # If contiguous() had to copy (non-contiguous out), write the results
    # back into the caller's tensor; otherwise dst already aliases out.
    if dst.data_ptr() != out.data_ptr():
        out.copy_(dst)
    return out