Coverage for src/flag_gems/experimental_ops/elu.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def elu_kernel(
    x_ptr, out_ptr, n_elements, alpha, scale, input_scale, BLOCK_SIZE: tl.constexpr
):
    """Elementwise ELU: out = scale * (x if x > 0 else alpha * (exp(input_scale * x) - 1))."""
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Load and promote to fp32 so exp() is computed at full precision
    # regardless of the storage dtype.
    val = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
    val_f32 = val.to(tl.float32)

    # Negative branch uses the exponential form; positive branch passes x through.
    neg_branch = alpha * (tl.exp(input_scale * val_f32) - 1.0)
    result = tl.where(val_f32 > 0.0, val_f32, neg_branch) * scale

    # Cast back to the input dtype before storing.
    tl.store(out_ptr + idx, result.to(val.dtype), mask=in_bounds)
26def _parse_elu_args(args, kwargs, expect_out: bool = False):
27 x = None
28 if len(args) > 0 and isinstance(args[0], torch.Tensor):
29 x = args[0]
30 arg_idx = 1
31 else:
32 x = kwargs.get("input", kwargs.get("self", kwargs.get("x", None)))
33 arg_idx = 0
35 if x is None:
36 raise ValueError("elu expects a Tensor as the first argument (input/self/x).")
38 def _get_scalar(name, default, idx):
39 if name in kwargs:
40 return float(kwargs[name])
41 elif len(args) > idx:
42 return float(args[idx])
43 else:
44 return float(default)
46 alpha = _get_scalar("alpha", 1.0, arg_idx + 0)
47 scale = _get_scalar("scale", 1.0, arg_idx + 1)
48 input_scale = _get_scalar("input_scale", 1.0, arg_idx + 2)
50 out = None
51 if expect_out:
52 if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
53 out = kwargs["out"]
54 elif len(args) > arg_idx + 3 and isinstance(args[arg_idx + 3], torch.Tensor):
55 out = args[arg_idx + 3]
56 elif len(args) > arg_idx + 4 and isinstance(args[arg_idx + 4], torch.Tensor):
57 out = args[arg_idx + 4]
58 else:
59 raise ValueError("elu_out expects an 'out' tensor argument.")
61 return x, alpha, scale, input_scale, out
def _launch_elu_kernel(
    x: torch.Tensor, out: torch.Tensor, alpha: float, scale: float, input_scale: float
):
    """Validate both tensors and launch the flat 1-D Triton ELU kernel.

    The kernel indexes both tensors as contiguous 1-D buffers of equal
    length and dtype, so every precondition below is required.
    """
    if not (x.is_cuda and out.is_cuda):
        raise RuntimeError("elu Triton kernel requires CUDA tensors.")
    if x.numel() != out.numel():
        raise ValueError("Input and output must have the same number of elements.")
    if x.dtype != out.dtype:
        raise ValueError("Input and output must have the same dtype.")
    if not (x.is_contiguous() and out.is_contiguous()):
        raise ValueError("Input and output must be contiguous tensors.")

    total = x.numel()
    BLOCK_SIZE = 1024

    # One program instance per BLOCK_SIZE-sized chunk of the flattened tensor.
    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    elu_kernel[grid](
        x,
        out,
        total,
        float(alpha),
        float(scale),
        float(input_scale),
        BLOCK_SIZE=BLOCK_SIZE,
    )
def elu(*args, **kwargs):
    """ELU activation computed by the Triton kernel.

    Computes scale * (x if x > 0 else alpha * (exp(input_scale * x) - 1))
    elementwise and returns a freshly allocated tensor shaped like the input.
    """
    x, alpha, scale, input_scale, _ = _parse_elu_args(args, kwargs, expect_out=False)
    xc = x.contiguous()
    # Allocate from the contiguous view: torch.empty_like(x) preserves x's
    # strides, so a non-contiguous input would otherwise produce a
    # non-contiguous `out` and fail the launcher's contiguity check.
    out = torch.empty_like(xc)
    _launch_elu_kernel(xc, out, alpha, scale, input_scale)
    return out
def elu_out(*args, **kwargs):
    """ELU activation written into a caller-provided 'out' tensor.

    Raises:
        ValueError: if no 'out' tensor was supplied.
    """
    parsed = _parse_elu_args(args, kwargs, expect_out=True)
    x, alpha, scale, input_scale, out = parsed
    if out is None:
        raise ValueError("elu_out requires an 'out' tensor.")
    _launch_elu_kernel(x.contiguous(), out, alpha, scale, input_scale)
    return out