Coverage for src/flag_gems/experimental_ops/erf_.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def erf_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite the first ``n_elements`` values at ``x_ptr`` with their error function.

    Each program instance handles one ``BLOCK_SIZE`` chunk of a flat buffer;
    lanes at or beyond ``n_elements`` are masked off. The math is carried out
    in float32 and the result is cast back to the element type of ``x_ptr``.
    """
    pid = tl.program_id(axis=0)
    # Promote to int64 before scaling: pid * BLOCK_SIZE in int32 wraps around
    # once the buffer holds 2**31 or more elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    x32 = x.to(tl.float32)
    ax = tl.abs(x32)

    # Large |x|: erf(x) = sign(x) * (1 - erfc(|x|)), with erfc(z) taken from
    # the Numerical Recipes "erfcc" fit:
    #   erfc(z) ~= t * exp(-z^2 - 1.26551223 + t * P(t)),  t = 1 / (1 + z/2)
    # P is evaluated with Horner's rule, innermost coefficient first.
    t = 1.0 / (1.0 + 0.5 * ax)
    p = -0.82215223 + t * 0.17087277
    p = 1.48851587 + t * p
    p = -1.13520398 + t * p
    p = 0.27886807 + t * p
    p = -0.18628806 + t * p
    p = 0.09678418 + t * p
    p = 0.37409196 + t * p
    p = 1.00002368 + t * p
    tau = t * tl.exp(-x32 * x32 - 1.26551223 + t * p)
    y_large = tl.where(x32 >= 0, 1.0 - tau, tau - 1.0)

    # Small |x|: erfc(|x|) is close to 1, so "1 - erfc" cancels almost every
    # significant bit (erf(0) would not even come out as 0). Use the Maclaurin
    # series instead:
    #   erf(x) = 2/sqrt(pi) * sum_k (-1)^k x^(2k+1) / (k! (2k+1))
    # Terms up to x^13 give truncation error far below float32 epsilon for
    # |x| < 0.5. It also maps -0.0 to -0.0, like torch.erf.
    x2 = x32 * x32
    q = -7.575757575757576e-04 + x2 * 1.0683760683760684e-04  # -1/1320, 1/9360
    q = 4.629629629629630e-03 + x2 * q  # 1/216
    q = -2.380952380952381e-02 + x2 * q  # -1/42
    q = 0.1 + x2 * q  # 1/10
    q = -0.3333333333333333 + x2 * q  # -1/3
    q = 1.0 + x2 * q
    y_small = x32 * (1.1283791670955126 * q)  # 2/sqrt(pi)

    # NaN fails the comparison and propagates through the large-|x| branch.
    y32 = tl.where(ax < 0.5, y_small, y_large)

    y = y32.to(x.dtype)
    tl.store(x_ptr + offsets, y, mask=mask)

45 

46 

# Keep a handle on the @triton.jit kernel defined above. The Python wrapper
# defined next reuses the name ``erf_`` and shadows it, so the wrapper
# launches the kernel through this alias instead.
erf__kernel = erf_

49 

50 

def erf_(*args, **kwargs):
    """Apply the error function to a tensor in place and return that tensor.

    The tensor is looked up, in order, as: the first positional argument,
    the ``input`` keyword, the ``self`` keyword, or the first element of an
    ``args`` keyword holding a list/tuple.

    CUDA tensors of dtype float16, bfloat16 or float32 are processed by the
    Triton kernel ``erf__kernel``. Any other tensor is delegated to PyTorch's
    own ``Tensor.erf_``.

    Returns:
        The same tensor object, modified in place.

    Raises:
        TypeError: if no tensor argument can be found.
    """
    x = None
    if args and isinstance(args[0], torch.Tensor):
        x = args[0]
    elif isinstance(kwargs.get("input"), torch.Tensor):
        x = kwargs["input"]
    elif isinstance(kwargs.get("self"), torch.Tensor):
        x = kwargs["self"]
    else:
        packed = kwargs.get("args")
        # Validate the packed payload too; otherwise a non-tensor would fail
        # later with an opaque AttributeError instead of the TypeError below.
        if (
            isinstance(packed, (list, tuple))
            and packed
            and isinstance(packed[0], torch.Tensor)
        ):
            x = packed[0]
    if x is None:
        raise TypeError("erf_ expects a tensor as its first argument")

    # Fallback for unsupported devices/dtypes.
    if not x.is_cuda or x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return x.erf_()

    n_elements = x.numel()
    if n_elements == 0:
        return x

    # The kernel addresses memory as a dense flat buffer of n_elements values
    # starting at the tensor's data pointer. A non-contiguous view (for example
    # x[:, ::2]) does not occupy such a buffer. Writing through it directly
    # would overwrite the wrong memory, so work on a contiguous copy instead.
    # contiguous() returns x itself when no copy is needed.
    work = x.contiguous()

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    # Launch on the tensor's own GPU, which may differ from the current device.
    with torch.cuda.device(x.device):
        erf__kernel[grid](work, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if work is not x:
        x.copy_(work)
    return x