Coverage for src/flag_gems/experimental_ops/prelu.py: 0%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def prelu(
    x_ptr,  # *Pointer* to input tensor.
    w_ptr,  # *Pointer* to weight tensor (scalar or per-channel vector).
    out_ptr,  # *Pointer* to output tensor.
    n_elements,  # Total number of elements in input.
    S,  # Spatial size = product of dims after channel dim (or 1 if none).
    C,  # Number of channels (or 1).
    w_is_scalar: tl.constexpr,  # Whether weight is a single scalar.
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise PReLU: out = x where x >= 0, else alpha * x.

    alpha is either a single scalar (w_is_scalar=True) or a per-channel
    value looked up via (flat_offset // S) % C, which assumes a contiguous
    layout with the channel dimension immediately before the spatial dims.
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    x = tl.load(x_ptr + offs, mask=in_bounds)

    # w_is_scalar is a constexpr, so only the taken branch is compiled in.
    if w_is_scalar:
        alpha = tl.load(w_ptr)
    else:
        channel = (offs // S) % C
        alpha = tl.load(w_ptr + channel, mask=in_bounds)

    result = tl.where(x >= 0, x, alpha * x)
    tl.store(out_ptr + offs, result, mask=in_bounds)
# Keep a reference to the Triton kernel before defining the Python wrapper with
# the same name: the wrapper defined next shadows `prelu`, so kernel launches
# must go through this alias instead.
prelu_kernel = prelu
def prelu(*args, **kwargs):
    """PReLU activation computed by the Triton kernel above.

    Accepts ``(input, weight)`` positionally, via the keywords
    ``input``/``self`` and ``weight``, or any mix of the two. ``weight`` is
    either a single scalar or a per-channel vector whose length matches the
    channel dimension (dim 0 for 1-D inputs, dim 1 otherwise).

    Returns:
        A new contiguous tensor with the same shape and dtype as ``input``.

    Raises:
        ValueError: if input or weight cannot be found in the arguments.
        AssertionError: if the tensors are not CUDA tensors, if a non-scalar
            weight is given for a 0-dim input, or if the weight size does not
            match the channel dimension.
    """
    # Extract inputs; support any mix of positional and keyword arguments.
    # (Previously `prelu(x, weight=w)` was rejected because only the
    # fully-positional case was handled.)
    x = args[0] if len(args) >= 1 else kwargs.get("input", kwargs.get("self"))
    weight = args[1] if len(args) >= 2 else kwargs.get("weight")
    if x is None or weight is None:
        raise ValueError("prelu expects (input, weight) as arguments.")

    if not (x.is_cuda and weight.is_cuda):
        raise AssertionError("Tensors must be CUDA tensors.")

    # Match the weight dtype to the input dtype so the kernel sees one type.
    if weight.dtype != x.dtype:
        weight = weight.to(dtype=x.dtype)

    # The kernel's flat-offset channel math requires contiguous layouts.
    x = x.contiguous()
    weight = weight.contiguous()

    out = torch.empty_like(x)

    n_elements = x.numel()
    if n_elements == 0:
        return out  # nothing to launch for an empty tensor

    # Determine channel count C and spatial size S (product of dims after the
    # channel dim) so the kernel can recover each element's channel index.
    ndim = x.dim()
    if weight.numel() == 1:
        C = 1
        S = 1
        w_is_scalar = True
    else:
        if ndim == 0:
            raise AssertionError("Non-scalar weight provided for a 0-dim input.")
        if ndim == 1:
            # 1-D input: treat dim 0 as the channel dimension.
            C = x.shape[0]
            S = 1
        else:
            # >= 2-D input: dim 1 is the channel dimension (NCHW-style).
            C = x.shape[1]
            S = 1
            for d in x.shape[2:]:
                S *= d
        if weight.numel() != C:
            raise AssertionError(
                f"Weight numel ({weight.numel()}) must equal channel dimension size ({C})."
            )
        w_is_scalar = False

    # Guard against div/mod by zero in the kernel (e.g. zero-sized dims).
    C = max(int(C), 1)
    S = max(int(S), 1)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    prelu_kernel[grid](
        x, weight, out, n_elements, S, C, w_is_scalar=w_is_scalar, BLOCK_SIZE=BLOCK_SIZE
    )
    return out