Coverage for src/flag_gems/experimental_ops/elu.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def elu_kernel(
    x_ptr, out_ptr, n_elements, alpha, scale, input_scale, BLOCK_SIZE: tl.constexpr
):
    """Elementwise ELU over a flat buffer.

    Computes out = scale * (x            if x > 0
                            alpha * (exp(input_scale * x) - 1)  otherwise)
    for each of the `n_elements` values, one BLOCK_SIZE chunk per program.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
    # Do the math in float32 regardless of storage dtype, then cast back on store.
    v = val.to(tl.float32)

    negative_branch = alpha * (tl.exp(input_scale * v) - 1.0)
    result = scale * tl.where(v > 0.0, v, negative_branch)

    tl.store(out_ptr + idx, result.to(val.dtype), mask=in_bounds)

24 

25 

26def _parse_elu_args(args, kwargs, expect_out: bool = False): 

27 x = None 

28 if len(args) > 0 and isinstance(args[0], torch.Tensor): 

29 x = args[0] 

30 arg_idx = 1 

31 else: 

32 x = kwargs.get("input", kwargs.get("self", kwargs.get("x", None))) 

33 arg_idx = 0 

34 

35 if x is None: 

36 raise ValueError("elu expects a Tensor as the first argument (input/self/x).") 

37 

38 def _get_scalar(name, default, idx): 

39 if name in kwargs: 

40 return float(kwargs[name]) 

41 elif len(args) > idx: 

42 return float(args[idx]) 

43 else: 

44 return float(default) 

45 

46 alpha = _get_scalar("alpha", 1.0, arg_idx + 0) 

47 scale = _get_scalar("scale", 1.0, arg_idx + 1) 

48 input_scale = _get_scalar("input_scale", 1.0, arg_idx + 2) 

49 

50 out = None 

51 if expect_out: 

52 if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor): 

53 out = kwargs["out"] 

54 elif len(args) > arg_idx + 3 and isinstance(args[arg_idx + 3], torch.Tensor): 

55 out = args[arg_idx + 3] 

56 elif len(args) > arg_idx + 4 and isinstance(args[arg_idx + 4], torch.Tensor): 

57 out = args[arg_idx + 4] 

58 else: 

59 raise ValueError("elu_out expects an 'out' tensor argument.") 

60 

61 return x, alpha, scale, input_scale, out 

62 

63 

64def _launch_elu_kernel( 

65 x: torch.Tensor, out: torch.Tensor, alpha: float, scale: float, input_scale: float 

66): 

67 if not x.is_cuda or not out.is_cuda: 

68 raise RuntimeError("elu Triton kernel requires CUDA tensors.") 

69 if x.numel() != out.numel(): 

70 raise ValueError("Input and output must have the same number of elements.") 

71 if x.dtype != out.dtype: 

72 raise ValueError("Input and output must have the same dtype.") 

73 if not x.is_contiguous() or not out.is_contiguous(): 

74 raise ValueError("Input and output must be contiguous tensors.") 

75 

76 n_elements = x.numel() 

77 BLOCK_SIZE = 1024 

78 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

79 elu_kernel[grid]( 

80 x, 

81 out, 

82 n_elements, 

83 float(alpha), 

84 float(scale), 

85 float(input_scale), 

86 BLOCK_SIZE=BLOCK_SIZE, 

87 ) 

88 

89 

def elu(*args, **kwargs):
    """Out-of-place ELU: parse the arguments, allocate a result tensor, run the kernel."""
    tensor, alpha, scale, input_scale, _unused = _parse_elu_args(
        args, kwargs, expect_out=False
    )
    result = torch.empty_like(tensor)
    _launch_elu_kernel(tensor.contiguous(), result, alpha, scale, input_scale)
    return result

95 

96 

def elu_out(*args, **kwargs):
    """In-place-style ELU writing into a caller-provided `out` tensor; returns `out`."""
    parsed = _parse_elu_args(args, kwargs, expect_out=True)
    tensor, alpha, scale, input_scale, out = parsed
    if out is None:
        raise ValueError("elu_out requires an 'out' tensor.")
    _launch_elu_kernel(tensor.contiguous(), out, alpha, scale, input_scale)
    return out