Coverage for src/flag_gems/experimental_ops/i0_.py: 0%
44 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def i0_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place modified Bessel function of the first kind, order zero.

    Each program instance handles BLOCK_SIZE contiguous elements of the flat
    buffer at ``x_ptr`` and overwrites them with i0(x).

    The evaluation uses the classic Abramowitz & Stegun 9.8.1 / 9.8.2
    rational approximations: a polynomial in (|x|/3.75)^2 for |x| <= 3.75,
    and exp(|x|)/sqrt(|x|) times a polynomial in 3.75/|x| otherwise.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # All arithmetic is carried out in float32 regardless of input dtype;
    # NOTE(review): float64 inputs therefore lose precision here — confirm
    # this is an intentional trade-off.
    xf = tl.cast(x, tl.float32)
    # i0 is even, so both branches work on |x|.
    ax = tl.abs(xf)
    # --- small-argument branch: polynomial in (|x|/3.75)^2 (A&S 9.8.1) ---
    t_small = ax / 3.75
    y_small = t_small * t_small
    poly_small = 1.0 + y_small * (
        3.5156229
        + y_small
        * (
            3.0899424
            + y_small
            * (
                1.2067492
                + y_small * (0.2659732 + y_small * (0.0360768 + y_small * 0.0045813))
            )
        )
    )
    # --- large-argument branch: polynomial in 3.75/|x| (A&S 9.8.2) ---
    y_large = 3.75 / ax
    poly_large = 0.39894228 + y_large * (
        0.01328592
        + y_large
        * (
            0.00225319
            + y_large
            * (
                -0.00157565
                + y_large
                * (
                    0.00916281
                    + y_large
                    * (
                        -0.02057706
                        + y_large
                        * (0.02635537 + y_large * (-0.01647633 + y_large * 0.00392377))
                    )
                )
            )
        )
    )
    # Note tl.exp(ax) overflows float32 for |x| >~ 88, yielding inf — the
    # same saturation behavior as a direct float32 exp.
    val_large = tl.exp(ax) * poly_large / tl.sqrt(ax)
    # Both branches are evaluated for every lane; for ax == 0 the large
    # branch divides by zero, but tl.where discards that lane's value in
    # favor of poly_small, so no bad value is ever selected.
    result = tl.where(ax <= 3.75, poly_small, val_large)
    # Cast back to the source dtype and write the result over the input.
    result_cast = tl.cast(result, x.dtype)
    tl.store(x_ptr + offsets, result_cast, mask=mask)
# Preserve a handle to the Triton kernel: the Python wrapper defined below
# reuses the name ``i0_`` and would otherwise shadow it, so kernel launches
# go through this alias.
i0__kernel = i0_
def i0_(*args, **kwargs):
    """In-place i0 (modified Bessel, order 0) on a CUDA tensor.

    The tensor may be supplied as the first positional argument or under one
    of the keywords 'input', 'self', or 'x'. It is overwritten with i0 of
    its values and returned.

    Raises:
        ValueError: when no tensor argument can be found.
        AssertionError: when the tensor is not on CUDA, not contiguous, or
            has an unsupported dtype.
    """
    if args:
        x = args[0]
    else:
        # Fall back to the conventional keyword spellings, in priority order.
        x = next((kwargs[name] for name in ("input", "self", "x") if name in kwargs), None)
    if x is None:
        raise ValueError(
            "i0_ expects a tensor as the first positional argument or in keyword 'input'/'self'/'x'."
        )

    # Validate placement, layout and dtype before launching anything.
    if not x.is_cuda:
        raise AssertionError("Input tensor must be on a CUDA device.")
    if not x.is_contiguous():
        raise AssertionError("Input tensor must be contiguous.")
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise AssertionError(
            "Unsupported dtype for i0_. Supported: float16, bfloat16, float32, float64."
        )

    total = x.numel()
    if total == 0:
        # Empty tensor: nothing to compute, skip the kernel launch.
        return x

    def grid(meta):
        # One program per BLOCK_SIZE-sized slice of the flattened tensor.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    i0__kernel[grid](x, total, BLOCK_SIZE=1024)
    return x