Coverage for src/flag_gems/experimental_ops/leaky_relu.py: 0%
36 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
from typing import Optional

import torch
import triton
import triton.language as tl
@triton.jit
def _leaky_relu_kernel(
    x_ptr, y_ptr, n_elements, negative_slope, BLOCK_SIZE: tl.constexpr
):
    """Elementwise LeakyReLU over a flat buffer of ``n_elements`` values.

    Each program instance handles one BLOCK_SIZE-wide tile of the input.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    zeros = tl.zeros([BLOCK_SIZE], dtype=vals.dtype)
    slopes = tl.full([BLOCK_SIZE], negative_slope, dtype=vals.dtype)
    # Branch-free form of `x if x >= 0 else negative_slope * x`:
    # the positive part passes through max(); the negative part is scaled.
    result = tl.maximum(vals, zeros) + slopes * tl.minimum(vals, zeros)
    tl.store(y_ptr + idx, result, mask=in_bounds)
25def _launch_leaky_relu_kernel(
26 x: torch.Tensor, out: torch.Tensor, negative_slope: float
27):
28 if not x.is_cuda or not out.is_cuda:
29 raise ValueError("Input and output tensors must be on CUDA device.")
30 if x.numel() != out.numel():
31 raise ValueError("Input and output must have the same number of elements.")
32 if x.dtype != out.dtype:
33 raise ValueError("Input and output tensors must have the same dtype.")
34 if not x.is_contiguous():
35 x = x.contiguous()
36 if not out.is_contiguous():
37 raise ValueError("Output tensor must be contiguous.")
39 n_elements = x.numel()
40 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
41 _leaky_relu_kernel[grid](x, out, n_elements, float(negative_slope), BLOCK_SIZE=1024)
42 return out
def leaky_relu(input: torch.Tensor, negative_slope: float = 0.01):
    """
    ATen: ('leaky_relu', <Autograd.disable: False>)

    Allocate a fresh tensor shaped like ``input`` and fill it with the
    elementwise LeakyReLU of ``input`` via the Triton kernel launcher.
    """
    result = torch.empty_like(input)
    _launch_leaky_relu_kernel(input, result, negative_slope)
    return result
def leaky_relu_out(
    input: torch.Tensor,
    negative_slope: float = 0.01,
    out: Optional[torch.Tensor] = None,
):
    """
    ATen: ('leaky_relu.out', <Autograd.disable: False>)

    Out-variant of :func:`leaky_relu` that writes the result into ``out``.

    Args:
        input: CUDA tensor to transform.
        negative_slope: Multiplier applied to negative input values.
        out: Required pre-allocated CUDA output tensor (contiguous, same
            dtype and element count as ``input``). Annotated Optional only
            to allow the keyword default; ``None`` is rejected at runtime.

    Returns:
        ``out``, filled with the result.

    Raises:
        ValueError: If ``out`` is None, or if the launch-time validation
            of devices, dtypes, sizes, or contiguity fails.
    """
    if out is None:
        raise ValueError("Argument 'out' must be provided for leaky_relu_out.")
    return _launch_leaky_relu_kernel(input, out, negative_slope)