Coverage for src/flag_gems/experimental_ops/erfinv_.py: 0%
46 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def erfinv_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Compute the inverse error function in place over a flat fp buffer.

    Each program instance processes one contiguous BLOCK_SIZE-wide slice of
    ``x_ptr`` and writes the result back to the same buffer (in-place).
    ``n_elements`` bounds the valid region; out-of-range lanes are masked.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load input and compute in fp32 for better precision
    # (input may be fp16/bf16 — assumed from the wrapper's dtype checks; confirm).
    y = tl.load(x_ptr + offsets, mask=mask, other=0)
    y32 = tl.cast(y, tl.float32)

    # Constants for the Winitzki approximation (a = 0.147) and Gauss PDF.
    a = 0.147
    inv_a = 1.0 / a
    PI = 3.14159265358979323846
    SQRT_PI = 1.77245385090551602729

    # Winitzki approximation for initial guess
    # w = ln(1 - y^2)
    w = tl.log(1.0 - y32 * y32)
    b = (2.0 / (PI * a)) + 0.5 * w
    # inner = sqrt(b^2 - w/a)
    inner = tl.sqrt(b * b - w * inv_a)
    # sign(y)
    s = tl.where(y32 >= 0.0, 1.0, -1.0)
    x0 = s * tl.sqrt(inner - b)
    # NOTE: for |y| >= 1 the log above yields -inf/nan, so x0 becomes
    # +/-inf or nan without special-casing; those lanes are excluded from
    # refinement by `valid` below and stored as-is.

    # Valid mask for refinement: -1 < y < 1
    valid = (y32 > -1.0) & (y32 < 1.0)

    # Newton refinement using an erf approximation (Abramowitz-Stegun 7.1.26)
    # Perform two iterations: x <- x - (erf(x) - y) / erf'(x)
    for _ in range(2):
        z = x0
        absz = tl.abs(z)
        t = 1.0 / (1.0 + 0.3275911 * absz)
        # Polynomial for approximation (Horner form, A&S 7.1.26 coefficients)
        poly = (
            ((((1.061405429 * t) - 1.453152027) * t) + 1.421413741) * t - 0.284496736
        ) * t + 0.254829592
        erf_abs = 1.0 - poly * t * tl.exp(-absz * absz)
        # A&S 7.1.26 covers z >= 0 only; extend by odd symmetry erf(-z) = -erf(z).
        erf_z = tl.where(z >= 0.0, erf_abs, -erf_abs)

        # Exact derivative: d/dz erf(z) = (2/sqrt(pi)) * exp(-z^2)
        derivative = (2.0 / SQRT_PI) * tl.exp(-(z * z))
        step = (erf_z - y32) / derivative
        # Only refine lanes with |y| < 1; leave inf/nan lanes untouched.
        x0 = tl.where(valid, z - step, z)

    # Store result back in-place
    tl.store(x_ptr + offsets, x0, mask=mask)
# Keep a handle to the Triton kernel before the Python wrapper below
# re-binds the module-level name `erfinv_`.
_erfinv_kernel = erfinv_
def erfinv_(*args, **kwargs):
    """In-place inverse error function (torch ``Tensor.erfinv_``-style entry point).

    The tensor may be passed positionally or as the ``input`` keyword.
    Contiguous CUDA float16/bfloat16/float32 tensors are handled by the
    Triton kernel; everything else (CPU, float64, non-contiguous) falls
    back to ``torch.special.erfinv`` copied back into the input.

    Returns:
        The input tensor, mutated in place.

    Raises:
        ValueError: if no tensor is supplied.
        TypeError: for contiguous CUDA tensors of unsupported dtypes.
    """
    # Expect a single tensor input
    x = args[0] if len(args) > 0 else kwargs.get("input", None)
    if x is None:
        raise ValueError("erfinv_ expects a tensor as the first argument.")

    # Fallback for unsupported cases
    if (not x.is_cuda) or (x.dtype == torch.float64) or (not x.is_contiguous()):
        x.copy_(torch.special.erfinv(x))
        return x

    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError(
            "erfinv_ Triton kernel supports float16, bfloat16, and float32 tensors."
        )

    n_elements = x.numel()
    # Nothing to compute for empty tensors; also avoids launching a
    # zero-sized grid.
    if n_elements == 0:
        return x

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _erfinv_kernel[grid](x, n_elements, BLOCK_SIZE=1024)
    return x