Coverage for src/flag_gems/ops/i0_.py: 46%
46 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
@triton.jit
def i0_kernel_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place modified Bessel function of the first kind, order 0.

    Each program instance handles one contiguous tile of BLOCK_SIZE
    elements, evaluates I0(|x|) with the two-branch polynomial
    approximation (Abramowitz & Stegun 9.8.1 / 9.8.2), and writes the
    result back over the input.

    Math is done in float32 regardless of the stored dtype; the result
    is cast back to the input dtype on store.
    NOTE(review): for float64 inputs this truncates precision to
    float32 — confirm that is acceptable for callers.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked lanes load 0.0; their results are discarded by the masked store.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    xf = tl.cast(x, tl.float32)
    ax = tl.abs(xf)

    # Branch 1: |x| <= 3.75 — polynomial in y = (|x| / 3.75)^2,
    # evaluated by Horner's rule (same association order as the
    # nested form, so results are bit-identical).
    t = ax / 3.75
    y = t * t
    p = 0.0045813
    p = 0.0360768 + y * p
    p = 0.2659732 + y * p
    p = 1.2067492 + y * p
    p = 3.0899424 + y * p
    p = 3.5156229 + y * p
    small = 1.0 + y * p

    # Branch 2: |x| > 3.75 — exp(|x|)/sqrt(|x|) times a polynomial
    # in z = 3.75 / |x|. For ax == 0 this lane produces inf/nan, but
    # tl.where below selects the small branch there, so the garbage
    # value is never stored.
    z = 3.75 / ax
    q = 0.00392377
    q = -0.01647633 + z * q
    q = 0.02635537 + z * q
    q = -0.02057706 + z * q
    q = 0.00916281 + z * q
    q = -0.00157565 + z * q
    q = 0.00225319 + z * q
    q = 0.01328592 + z * q
    q = 0.39894228 + z * q
    large = tl.exp(ax) * q / tl.sqrt(ax)

    result = tl.where(ax <= 3.75, small, large)
    # Write back in place, in the original storage dtype.
    tl.store(x_ptr + offsets, tl.cast(result, x.dtype), mask=mask)
def i0_(*args, **kwargs):
    """In-place modified Bessel function of the first kind, order 0.

    The input tensor is taken from the first positional argument, or
    from the keyword ``input``/``self``/``x``. The tensor is modified
    in place and also returned.

    Raises:
        ValueError: if no tensor argument can be found.
        AssertionError: if the tensor is not on a CUDA device, is not
            contiguous, or has a dtype outside
            {float16, bfloat16, float32, float64}.
    """
    logger.debug("GEMS I0_")
    if args:
        x = args[0]
    else:
        # Accept the common keyword spellings used by torch overloads.
        x = next((kwargs[k] for k in ("input", "self", "x") if k in kwargs), None)
    if x is None:
        raise ValueError(
            "i0_ expects a tensor as the first positional argument or in keyword 'input'/'self'/'x'."
        )
    if not x.is_cuda:
        raise AssertionError("Input tensor must be on a CUDA device.")
    if not x.is_contiguous():
        raise AssertionError("Input tensor must be contiguous.")
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise AssertionError(
            "Unsupported dtype for i0_. Supported: float16, bfloat16, float32, float64."
        )
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to do; an empty tensor is returned unchanged.
        return x
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # Fix: Triton launches on the *current* CUDA device, which may differ
    # from x.device on multi-GPU hosts; pin the launch to x's device.
    with torch.cuda.device(x.device):
        i0_kernel_[grid](x, n_elements, BLOCK_SIZE=1024)
    return x