Coverage for src/flag_gems/experimental_ops/erf_.py: 0%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def erf_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite ``x_ptr[0:n_elements]`` with ``erf`` of each element, in place.

    Each program instance handles one ``BLOCK_SIZE``-wide slice of a flat
    buffer; lanes at or past ``n_elements`` are masked out of the load/store.
    Arithmetic is done in float32 and the result is cast back to the loaded
    element type.

    Two regimes are evaluated and selected per element with ``tl.where``:

    * ``|x| < 0.5``: truncated Maclaurin series of erf. The erfc-based form
      below evaluates erf as ``1 - erfc(|x|)``, which cancels catastrophically
      near zero (erf(0) is not guaranteed to be exactly 0 and tiny inputs lose
      all significant digits), so small inputs use a direct formula.
    * ``|x| >= 0.5`` (and NaN, since ``NaN < 0.5`` is false): the Numerical
      Recipes ``erfcc`` fit for erfc(|x|) (quoted fractional error < 1.2e-7),
      mirrored by the sign of x.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    x32 = x.to(tl.float32)
    ax = tl.abs(x32)
    x2 = x32 * x32

    # Small-|x| branch:
    #   erf(x) = 2/sqrt(pi) * sum_n (-1)^n x^(2n+1) / (n! (2n+1)),
    # kept through the x^11 term and evaluated in Horner form in x^2. The
    # series alternates, so the truncation error is below the x^13 term
    # (~1.5e-8 at |x| = 0.5), i.e. under float32 resolution on this range.
    q = x2 * -0.0007575757575757576 + 0.004629629629629629  # -1/1320, 1/216
    q = q * x2 - 0.023809523809523808  # 1/42
    q = q * x2 + 0.1  # 1/10
    q = q * x2 - 0.3333333333333333  # 1/3
    q = q * x2 + 1.0
    y_small = 1.1283791670955126 * x32 * q  # 2/sqrt(pi) * x * series

    # Large-|x| branch (Numerical Recipes erfcc):
    #   erfc(|x|) ~= t * exp(-x^2 - 1.26551223 + t * P(t)),  t = 1 / (1 + |x|/2)
    # with P evaluated by Horner's rule, then erf = +/-(1 - erfc(|x|)).
    t = 1.0 / (1.0 + 0.5 * ax)
    p = t * 0.17087277 - 0.82215223
    p = p * t + 1.48851587
    p = p * t - 1.13520398
    p = p * t + 0.27886807
    p = p * t - 0.18628806
    p = p * t + 0.09678418
    p = p * t + 0.37409196
    p = p * t + 1.00002368
    tau = t * tl.exp(-x2 - 1.26551223 + t * p)
    y_large = tl.where(x32 >= 0, 1.0 - tau, tau - 1.0)

    y32 = tl.where(ax < 0.5, y_small, y_large)
    tl.store(x_ptr + offsets, y32.to(x.dtype), mask=mask)
# Keep a handle on the Triton kernel defined above: the `def erf_` that follows
# rebinds the module-level name `erf_` to the Python wrapper, and that wrapper
# launches the kernel through this alias. This line must stay between the two
# definitions.
erf__kernel = erf_
def erf_(*args, **kwargs):
    """Apply the error function to a tensor in place and return that tensor.

    The tensor is taken from, in order of preference: the first positional
    argument, the ``input`` keyword, the ``self`` keyword, or the first element
    of an ``args`` keyword list/tuple.

    CUDA tensors of dtype float16, bfloat16 or float32 are processed by the
    Triton kernel ``erf__kernel``; any other device or dtype is delegated to
    PyTorch's own ``Tensor.erf_``.

    Returns:
        The same tensor object that was passed in, mutated in place.

    Raises:
        TypeError: if no tensor can be located in the arguments.
    """
    x = None
    if args and isinstance(args[0], torch.Tensor):
        x = args[0]
    elif isinstance(kwargs.get("input"), torch.Tensor):
        x = kwargs["input"]
    elif isinstance(kwargs.get("self"), torch.Tensor):
        x = kwargs["self"]
    else:
        packed = kwargs.get("args")
        # Only accept the packed form when its first element really is a
        # tensor; otherwise fall through to the TypeError below instead of
        # failing later with an AttributeError.
        if (
            isinstance(packed, (list, tuple))
            and packed
            and isinstance(packed[0], torch.Tensor)
        ):
            x = packed[0]
    if x is None:
        raise TypeError("erf_ expects a tensor as its first argument")

    # Fallback for devices/dtypes the kernel does not handle.
    # NOTE(review): if this function is registered as the aten erf_ override,
    # this fallback would re-enter it — confirm the registration scope.
    if not x.is_cuda or x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return x.erf_()

    n_elements = x.numel()
    if n_elements == 0:
        return x

    # The kernel walks a flat buffer of `n_elements` consecutive elements, so
    # it must see contiguous storage. For strided/sliced views, compute on a
    # contiguous copy and write the result back into `x` afterwards.
    buf = x if x.is_contiguous() else x.contiguous()

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # Launch on x's own GPU even when a different device is current.
    with torch.cuda.device(x.device):
        erf__kernel[grid](buf, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if buf is not x:
        x.copy_(buf)
    return x