Coverage for src/flag_gems/experimental_ops/erfinv_.py: 0%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def erfinv_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite a flat buffer with the elementwise inverse error function.

    Each program instance handles one tile of ``BLOCK_SIZE`` consecutive
    elements starting at ``program_id(0) * BLOCK_SIZE``.  Lanes whose offset
    is at or beyond ``n_elements`` are masked out of the load and the store.

    Values are converted to float32 for the computation and written back
    through the same pointer.  The result is ``y * p(w)`` with
    ``w = -log((1 - y) * (1 + y))``, where ``p`` is the two-branch
    polynomial from M. Giles, "Approximating the erfinv function": one
    polynomial in ``w - 2.5`` for ``w < 5`` and one in ``sqrt(w) - 3`` for
    the tails.

    Special values: ``y == +/-1`` gives ``+/-inf``; ``|y| > 1`` and NaN give
    NaN; ``y == +/-0`` gives ``+/-0`` because the result is a multiple of y.

    Args:
        x_ptr: Pointer to the first element; read and overwritten in place.
        n_elements: Number of valid elements reachable from ``x_ptr``.
        BLOCK_SIZE: Compile-time number of elements per program instance.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Compute in float32 regardless of the storage dtype.
    y = tl.load(x_ptr + offsets, mask=mask, other=0).to(tl.float32)

    # (1 - y) * (1 + y) instead of 1 - y * y keeps precision when |y| is
    # close to 1.  The argument is negative for |y| > 1, so w becomes NaN
    # there and the NaN propagates to the result.
    w = -tl.log((1.0 - y) * (1.0 + y))

    # Central region (w < 5): polynomial in (w - 2.5), Horner form.
    wc = w - 2.5
    pc = 3.43273939e-07 + 2.81022636e-08 * wc
    pc = -3.5233877e-06 + pc * wc
    pc = -4.39150654e-06 + pc * wc
    pc = 0.00021858087 + pc * wc
    pc = -0.00125372503 + pc * wc
    pc = -0.00417768164 + pc * wc
    pc = 0.246640727 + pc * wc
    pc = 1.50140941 + pc * wc

    # Tail region (w >= 5): polynomial in (sqrt(w) - 3), Horner form.
    # Both branches are evaluated for every lane; tl.where picks one, so a
    # NaN produced by the branch that is not selected is discarded.
    wt = tl.sqrt(w) - 3.0
    pt = 0.000100950558 + -0.000200214257 * wt
    pt = 0.00134934322 + pt * wt
    pt = -0.00367342844 + pt * wt
    pt = 0.00573950773 + pt * wt
    pt = -0.0076224613 + pt * wt
    pt = 0.00943887047 + pt * wt
    pt = 1.00167406 + pt * wt
    pt = 2.83297682 + pt * wt

    result = y * tl.where(w < 5.0, pc, pt)

    # At y == +/-1, w is +inf and the tail polynomial evaluates to -inf,
    # which would give the wrong sign; set the limit value explicitly.
    result = tl.where(tl.abs(y) == 1.0, y * float("inf"), result)

    # The store converts the float32 result to the pointer's element type.
    tl.store(x_ptr + offsets, result, mask=mask)


# Keep a handle to the Triton kernel: the Python wrapper defined after this
# line rebinds the name ``erfinv_``.
_erfinv_kernel = erfinv_

59 

60 

def erfinv_(*args, **kwargs):
    """Apply the inverse error function to a tensor in place and return it.

    The tensor is taken from the first positional argument or, if there are
    no positional arguments, from the ``input`` keyword.

    Non-CUDA tensors, float64 tensors and non-contiguous tensors are handled
    by ``torch.special.erfinv`` followed by an in-place copy.  Contiguous
    CUDA tensors of dtype float16, bfloat16 or float32 are processed by the
    Triton kernel bound to ``_erfinv_kernel``.

    Returns:
        The tensor that was passed in, holding the result.

    Raises:
        ValueError: If no tensor was supplied.
        TypeError: If a contiguous CUDA tensor has a dtype other than
            float16, bfloat16, float32 or float64.
    """
    x = args[0] if args else kwargs.get("input")
    if x is None:
        raise ValueError("erfinv_ expects a tensor as the first argument.")

    # Cases the kernel does not cover: it indexes a flat contiguous buffer
    # and computes in float32.
    if (not x.is_cuda) or (x.dtype == torch.float64) or (not x.is_contiguous()):
        x.copy_(torch.special.erfinv(x))
        return x

    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError(
            "erfinv_ Triton kernel supports float16, bfloat16, and float32 tensors."
        )

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; also avoids launching a zero-sized grid.
        return x

    def grid(meta):
        # One program instance per BLOCK_SIZE-sized tile.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch on the tensor's own device, which may not be the current one.
    with torch.cuda.device(x.device):
        _erfinv_kernel[grid](x, n_elements, BLOCK_SIZE=1024)
    return x