Coverage for src/flag_gems/experimental_ops/leaky_relu_.py: 0%
44 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def leaky_relu_(
    x_ptr,  # *Pointer* to the tensor buffer; results are written back over it.
    n_elements,  # Total number of elements in the buffer.
    negative_slope,  # Scalar slope applied to negative inputs.
    BLOCK_SIZE: tl.constexpr,
):
    # Apply LeakyReLU element-wise, in place: x = x if x >= 0 else x * slope.
    # Each program instance processes one contiguous BLOCK_SIZE chunk.
    start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)
    out = tl.where(vals >= 0, vals, vals * negative_slope)
    tl.store(x_ptr + idx, out, mask=in_bounds)


# Keep a private handle to the compiled kernel before the Python wrapper
# below rebinds the module-level name `leaky_relu_`.
_leaky_relu_kernel = leaky_relu_
def leaky_relu_(*args, **kwargs):
    """In-place LeakyReLU: ``x = x if x >= 0 else negative_slope * x``.

    Mirrors ``torch.ops.aten.leaky_relu_(input, negative_slope=0.01)``.
    The tensor may be given positionally or via the ``self``/``input``
    keyword; the slope positionally or via ``negative_slope`` (a Python
    number or a 0-dim tensor).

    Returns:
        The input tensor, mutated in place.

    Raises:
        TypeError: if no argument is supplied or it is not a torch.Tensor.
    """
    # --- Argument parsing -------------------------------------------------
    if args:
        x = args[0]
    else:
        x = kwargs.get("self", kwargs.get("input", None))
        if x is None:
            raise TypeError("leaky_relu_ expected a tensor as the first argument")

    negative_slope = 0.01
    if len(args) >= 2:
        negative_slope = args[1]
    else:
        negative_slope = kwargs.get("negative_slope", negative_slope)

    if isinstance(negative_slope, torch.Tensor):
        negative_slope = negative_slope.item()
    negative_slope = float(negative_slope)

    if not isinstance(x, torch.Tensor):
        raise TypeError("leaky_relu_ expected a torch.Tensor")

    # --- Fallbacks ---------------------------------------------------------
    # Non-CUDA tensors, zero-element tensors, and dtypes Triton math does
    # not handle well all dispatch to the ATen in-place kernel.
    supported_dtypes = (torch.float16, torch.bfloat16, torch.float32)
    if not x.is_cuda or x.numel() == 0 or x.dtype not in supported_dtypes:
        return torch.ops.aten.leaky_relu_(x, negative_slope)

    # --- Triton path ---------------------------------------------------------
    if x.is_contiguous():
        _launch_leaky_relu(x, negative_slope)
    else:
        # The kernel assumes a flat contiguous buffer: run it on a
        # contiguous copy, then write the result back into x.
        tmp = x.contiguous()
        _launch_leaky_relu(tmp, negative_slope)
        x.copy_(tmp)
    return x


def _launch_leaky_relu(buf, negative_slope):
    """Launch the in-place Triton LeakyReLU kernel over a contiguous tensor."""
    n_elements = buf.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _leaky_relu_kernel[grid](buf, n_elements, negative_slope, BLOCK_SIZE=1024)