Coverage for src/flag_gems/ops/special_i1.py: 36%
74 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
# Module-level logger; the public ops below emit a debug message on each call.
logger = logging.getLogger(__name__)
@triton.jit
def special_i1_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise modified Bessel function of the first kind, order one: I1(x).

    Uses the polynomial approximations of Abramowitz & Stegun 9.8.3
    (|x| <= 3.75) and 9.8.4 (|x| > 3.75). All arithmetic is done in float32
    (the input is cast with ``.to(tl.float32)``) and the result is cast back
    to the input element type, so float64 inputs receive float32 accuracy and
    overflow range.

    Args:
        x_ptr: pointer to the input elements, indexed as a flat buffer.
        out_ptr: pointer to the output elements, indexed as a flat buffer.
        n_elements: number of elements to process.
        BLOCK_SIZE: number of elements handled by one program instance.
    """
    # int64 program id so ``pid * BLOCK_SIZE`` cannot overflow int32 for
    # tensors with more than 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    x_f32 = x.to(tl.float32)
    ax = tl.abs(x_f32)

    # Small region, |x| <= 3.75 (A&S 9.8.3): I1(x) = x * P((x / 3.75)^2),
    # evaluated in Horner form.
    y = x_f32 / 3.75
    y2 = y * y
    p = 0.00032411
    p = 0.00301532 + y2 * p
    p = 0.02658733 + y2 * p
    p = 0.15084934 + y2 * p
    p = 0.51498869 + y2 * p
    p = 0.87890594 + y2 * p
    p = 0.5 + y2 * p
    ans_small = x_f32 * p

    # Large region, |x| > 3.75 (A&S 9.8.4):
    #   I1(|x|) = exp(|x|) / sqrt(|x|) * Q(3.75 / |x|)
    # The clamp only protects lanes that end up taking the small branch
    # (e.g. x == 0) from a division by zero; their value is discarded below.
    ax_safe = tl.maximum(ax, 1e-20)
    t = 3.75 / ax_safe
    q = -0.00420059
    q = 0.01787654 + t * q
    q = -0.02895312 + t * q
    q = 0.02282967 + t * q
    q = -0.01031555 + t * q
    q = 0.00163801 + t * q
    q = -0.00362018 + t * q
    q = -0.03988024 + t * q
    q = 0.39894228 + t * q
    # exp(|x|) alone overflows float32 at |x| ~ 88.7 while I1 itself stays
    # finite up to |x| ~ 91.9, so apply the exponential in two halves.
    half_exp = tl.exp(0.5 * ax)
    ans_large = half_exp * (half_exp * q / tl.sqrt(ax_safe))
    # At |x| == inf the product above is inf * (inf / inf) = NaN; the limit is +inf.
    ans_large = tl.where(ax == float("inf"), float("inf"), ans_large)
    # I1 is an odd function.
    ans_large = tl.where(x_f32 < 0, -ans_large, ans_large)

    is_small = ax <= 3.75
    ans = tl.where(is_small, ans_small, ans_large)

    # Cast back to the input dtype and store.
    tl.store(out_ptr + offsets, ans.to(x.dtype), mask=mask)
def _launch_special_i1(x: torch.Tensor, out: torch.Tensor):
    """Launch ``special_i1_kernel`` to write I1(x) into ``out``.

    The kernel indexes both tensors as flat buffers, so both must be
    contiguous, on the same device, and have the same dtype and number of
    elements. Empty tensors are a no-op.

    Raises:
        ValueError: if the devices or element counts differ, or either
            tensor is not contiguous.
        TypeError: if the dtypes differ.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    # No CUDA-specific check: the launch goes through the vendor-neutral
    # ``torch_device_fn``.
    if x.device != out.device:
        raise ValueError("Input and output must be on the same device")
    if x.numel() != out.numel():
        raise ValueError("Input and output must have the same number of elements")
    if x.dtype != out.dtype:
        raise TypeError("Input and output must have the same dtype")
    if not (x.is_contiguous() and out.is_contiguous()):
        raise ValueError("Input and output must be contiguous")

    n_elements = x.numel()
    if n_elements == 0:
        return

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    # Make x's device current for the duration of the launch.
    with torch_device_fn.device(x.device):
        special_i1_kernel[grid](x, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
def special_i1(self: torch.Tensor):
    """Return I1(self), the modified Bessel function of the first kind, order one.

    Integer and bool inputs are promoted to ``torch.get_default_dtype()``,
    matching ``torch.special.i1``; floating inputs keep their dtype. The
    result is a new contiguous tensor with the same shape as ``self``.
    """
    logger.debug("GEMS SPECIAL_I1")
    x = self
    # Without promotion an integer input would be computed and then cast
    # back to an integer dtype, truncating the result. Complex inputs are
    # passed through unchanged rather than silently dropping the imaginary part.
    if not (x.is_floating_point() or x.is_complex()):
        x = x.to(torch.get_default_dtype())
    x_c = x.contiguous()
    out = torch.empty_like(x_c)
    _launch_special_i1(x_c, out)
    # ``out`` is contiguous and already has x's shape; no view is needed.
    return out
def special_i1_out(self: torch.Tensor, out: torch.Tensor):
    """Compute I1(self) into ``out`` and return ``out``.

    The result dtype is ``self.dtype`` for floating inputs and
    ``torch.get_default_dtype()`` for integer/bool inputs; ``out`` must have
    that dtype and live on the same device as ``self``. If ``out`` has a
    different shape it is resized to ``self.shape``. A non-contiguous ``out``
    is filled through a contiguous temporary and copied back.

    Raises:
        TypeError: if ``out`` has the wrong dtype or is on another device.
    """
    logger.debug("GEMS SPECIAL_I1_OUT")
    x = self
    if x.is_floating_point() or x.is_complex():
        result_dtype = x.dtype
    else:
        # Integer/bool inputs are promoted, as in torch.special.i1.
        result_dtype = torch.get_default_dtype()
    if out.dtype != result_dtype:
        raise TypeError(
            f"out dtype must match input dtype (result dtype {result_dtype}, "
            f"got {out.dtype})"
        )
    if out.device != x.device:
        raise TypeError("out device must match input device")

    # torch ``out=`` semantics: resize the output to the result shape.
    if out.shape != x.shape:
        out.resize_(x.shape)

    x_c = x.to(result_dtype).contiguous()
    out_c = out.contiguous()
    _launch_special_i1(x_c, out_c)
    # contiguous() returns ``out`` itself when no copy was needed.
    if out_c is not out:
        out.copy_(out_c)
    return out