Coverage for src/flag_gems/experimental_ops/erfinv.py: 0%
78 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def erfinv_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise inverse error function over a flat buffer.

    Evaluates a two-branch polynomial approximation in w = -log(1 - x^2).
    NOTE(review): the coefficients appear to match M. Giles'
    single-precision erfinv ("Approximating the erfinv function",
    GPU Computing Gems) — confirm against that reference before editing.

    Args:
        x_ptr: pointer to the input values.
        out_ptr: pointer to the output buffer (same dtype as input).
        n_elements: total number of elements to process.
        BLOCK_SIZE: compile-time number of elements handled per program.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the final partial block.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Compute in float32 regardless of the input dtype; cast back at the end.
    xf = x.to(tl.float32)

    one = 1.0
    absx = tl.abs(xf)
    # w = -log(1 - x^2), written as (1 - x)(1 + x) rather than 1 - x*x
    # (better float behavior near |x| = 1). For |x| >= 1 the argument is
    # <= 0, producing inf/NaN here; those lanes are overwritten below.
    w = -tl.log((one - xf) * (one + xf))

    # Branch select: central region (w < 5) vs. tail region.
    use_low = w < 5.0

    # Central branch: degree-8 polynomial in (w - 2.5), Horner form.
    wl = w - 2.5
    pl = 2.81022636e-08
    pl = 3.43273939e-07 + pl * wl
    pl = -3.5233877e-06 + pl * wl
    pl = -4.39150654e-06 + pl * wl
    pl = 2.1858087e-04 + pl * wl
    pl = -1.25372503e-03 + pl * wl
    pl = -4.17768164e-03 + pl * wl
    pl = 2.46640727e-01 + pl * wl
    pl = 1.50140941e00 + pl * wl

    # Tail branch: degree-8 polynomial in (sqrt(w) - 3), Horner form.
    wh = tl.sqrt(w) - 3.0
    ph = -2.00214257e-04
    ph = 1.00950558e-04 + ph * wh
    ph = 1.34934322e-03 + ph * wh
    ph = -3.67342844e-03 + ph * wh
    ph = 5.73950773e-03 + ph * wh
    ph = -7.62246130e-03 + ph * wh
    ph = 9.43887047e-03 + ph * wh
    ph = 1.00167406e00 + ph * wh
    ph = 2.83297682e00 + ph * wh

    # Both branches are evaluated; pick per-lane. erfinv(x) ~= p(w) * x.
    p = tl.where(use_low, pl, ph)
    res = p * xf

    nan_vec = tl.full([BLOCK_SIZE], float("nan"), dtype=tl.float32)
    inf_vec = tl.full([BLOCK_SIZE], float("inf"), dtype=tl.float32)

    # Special-case masks: NaN input (x != x), out-of-domain |x| > 1,
    # and the exact poles at x = +-1.
    mask_nan = xf != xf
    mask_oob = absx > 1.0
    mask_pos1 = xf == 1.0
    mask_neg1 = xf == -1.0

    # Order matters: the +-1 overrides come last so they win over the
    # generic polynomial result.
    res = tl.where(mask_nan, nan_vec, res)
    res = tl.where(mask_oob, nan_vec, res)
    res = tl.where(mask_pos1, inf_vec, res)
    res = tl.where(mask_neg1, -inf_vec, res)

    # Cast back to the caller's dtype before storing.
    y = res.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)
def _launch_erfinv_kernel(x: torch.Tensor, out: torch.Tensor):
    """Validate the tensor pair and dispatch the Triton erfinv kernel.

    Both tensors must live on CUDA, share a dtype, and hold the same
    number of elements; the kernel is launched over a 1-D grid sized to
    cover every element.
    """
    assert x.is_cuda and out.is_cuda, "Inputs must be CUDA tensors"
    assert (
        x.numel() == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"

    total = x.numel()
    block = 1024

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the flat buffer.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    erfinv_kernel[grid](x, out, total, BLOCK_SIZE=block)
def erfinv(x: torch.Tensor):
    """Return a new tensor holding erfinv of each element of ``x``.

    ``x`` must be a CUDA tensor. Non-contiguous inputs are computed on a
    contiguous copy; the result is then copied element-for-element into a
    tensor allocated with ``x``'s layout.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    x_in = x if x.is_contiguous() else x.contiguous()
    out = torch.empty_like(x_in)
    _launch_erfinv_kernel(x_in, out)
    if out.stride() != x.stride():
        # BUG FIX: the previous code returned
        # out.reshape(x.shape).as_strided(x.size(), x.stride()), which
        # reinterprets the *contiguous* result buffer through x's strides —
        # that scrambles element positions (and aliases elements for
        # overlapping strides). copy_ performs a layout-aware elementwise
        # copy instead, preserving values while matching x's layout.
        result = torch.empty_like(x)
        result.copy_(out)
        return result
    return out
def erfinv_out(x: torch.Tensor, out: torch.Tensor):
    """Compute erfinv(x) into ``out`` (aten-style ``out=`` variant).

    ``out`` is resized to ``x``'s shape when they differ and must share
    ``x``'s dtype. Non-contiguous ``out`` tensors are filled through a
    contiguous scratch buffer. Returns ``out``.
    """
    # aten out-semantics: the destination adopts the input's shape.
    if out.shape != x.shape:
        out.resize_(x.shape)
    assert out.dtype == x.dtype, "out tensor must have the same dtype as input"

    src = x.contiguous() if not x.is_contiguous() else x
    if not out.is_contiguous():
        # Kernel writes flat contiguous memory; go through a scratch
        # buffer and let copy_ handle out's layout.
        scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_erfinv_kernel(src, scratch)
        out.copy_(scratch)
        return out
    _launch_erfinv_kernel(src, out)
    return out