Coverage for src/flag_gems/experimental_ops/relu_.py: 0% (28 statements)
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch # noqa: F401
2import triton
3import triton.language as tl
@triton.jit
def relu_(
    x_ptr,  # *Pointer* to input/output tensor (in-place).
    n_elements,  # Number of elements.
    BLOCK_SIZE: tl.constexpr,
):
    """Clamp one BLOCK_SIZE-wide tile of the tensor to [0, inf) in place."""
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements
    vals = tl.load(x_ptr + offs, mask=in_bounds)
    # `vals * 0` produces a zero with the same dtype as the loaded block,
    # avoiding any implicit promotion a literal 0.0 might introduce.
    result = tl.where(vals > 0, vals, vals * 0)
    tl.store(x_ptr + offs, result, mask=in_bounds)
# Keep a reference to the Triton kernel before defining the Python wrapper
# with the same name below; the wrapper launches the kernel through this alias.
relu__kernel = relu_
def relu_(*args, **kwargs):
    """Apply ReLU in place on a contiguous CUDA tensor via the Triton kernel.

    The tensor is taken from the first positional argument, or from the
    ``input``/``x`` keyword. The same tensor object is returned after being
    mutated, mirroring ``torch.Tensor.relu_`` semantics.

    Returns:
        torch.Tensor: the input tensor, modified in place.

    Raises:
        ValueError: if no tensor is supplied, or the tensor is not on CUDA,
            or it is not contiguous.
        TypeError: if the supplied argument is not a ``torch.Tensor``.
    """
    # Expect the first positional argument to be the tensor.
    x = args[0] if args else kwargs.get("input", kwargs.get("x", None))
    if x is None:
        raise ValueError("relu_ expects a tensor as the first positional argument.")
    # Fail early with a clear message instead of an opaque AttributeError
    # when something like a list or numpy array is passed.
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"relu_ expects a torch.Tensor, got {type(x).__name__}.")
    if not x.is_cuda:
        raise ValueError("relu_ Triton implementation requires a CUDA tensor.")
    if not x.is_contiguous():
        raise ValueError("relu_ Triton implementation requires a contiguous tensor.")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to do; launching with an empty grid would be invalid.
        return x

    # One program instance per BLOCK_SIZE-element tile.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    relu__kernel[grid](x, n_elements, BLOCK_SIZE=1024)
    return x