Coverage for src/flag_gems/ops/prelu.py: 65%
62 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
# Triton kernel: elementwise PReLU, out[i] = x[i] if x[i] >= 0 else alpha * x[i].
# alpha is a single scalar when w_is_scalar, otherwise one slope per channel,
# where the channel of flat offset i is (i // S) % C for a contiguous input.
@triton.jit
def prelu(
    x_ptr,  # *Pointer* to input tensor.
    w_ptr,  # *Pointer* to weight tensor (scalar or per-channel vector).
    out_ptr,  # *Pointer* to output tensor.
    n_elements,  # Total number of elements in input.
    S,  # Spatial size = product of dims after channel dim (or 1 if none).
    C,  # Number of channels (or 1).
    w_is_scalar: tl.constexpr,  # Whether weight is a single scalar.
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    x = tl.load(x_ptr + offs, mask=in_bounds)

    # Resolve the negative-side slope for each lane; the w_is_scalar branch is
    # a constexpr, so only one of the two load patterns is compiled.
    if w_is_scalar:
        alpha = tl.load(w_ptr)
    else:
        chan = (offs // S) % C
        alpha = tl.load(w_ptr + chan, mask=in_bounds)

    y = tl.where(x >= 0, x, alpha * x)
    tl.store(out_ptr + offs, y, mask=in_bounds)
# Keep a reference to the Triton kernel before defining the Python wrapper with
# the same name below; the wrapper launches the kernel through this alias.
prelu_kernel = prelu
def _prelu_layout(x, weight):
    """Return ``(C, S, w_is_scalar)`` describing how ``weight`` maps onto ``x``.

    ``C`` is the channel count and ``S`` the number of elements per channel
    entry (product of dims after the channel dim); the kernel recovers the
    channel of flat index ``i`` as ``(i // S) % C`` for a contiguous tensor.

    Raises:
        AssertionError: if a non-scalar weight is given for a 0-dim input, or
            its numel does not match the channel dimension.
    """
    if weight.numel() == 1:
        return 1, 1, True

    ndim = x.dim()
    if ndim == 0:
        raise AssertionError("Non-scalar weight provided for a 0-dim input.")
    if ndim == 1:
        # 1-d input: dim 0 plays the role of the channel dim here.
        C, S = x.shape[0], 1
    else:
        C = x.shape[1]
        S = 1
        for d in x.shape[2:]:
            S *= d
    if weight.numel() != C:
        raise AssertionError(
            f"Weight numel ({weight.numel()}) must equal channel dimension size ({C})."
        )
    # Clamp to >= 1 so the kernel's // and % are well-defined even for
    # zero-size trailing dims (n_elements is 0 then anyway).
    return max(int(C), 1), max(int(S), 1), False


def prelu(*args, **kwargs):
    """PReLU: ``out = x`` where ``x >= 0``, else ``weight * x``, elementwise.

    Accepts ``(input, weight)`` positionally and/or via the ``input``/``self``
    and ``weight`` keywords. ``weight`` must hold either a single value
    (applied everywhere) or one value per channel.

    Returns:
        A new tensor with the same shape and dtype as ``input``.

    Raises:
        ValueError: if input or weight cannot be resolved from the arguments.
        AssertionError: if tensors are not CUDA tensors or the weight size
            does not match the channel dimension.
    """
    logger.debug("GEMS PRELU")
    # Extract inputs. Positional args win; any missing one may still be passed
    # as a keyword (e.g. prelu(x, weight=w) supplies x positionally only).
    x = args[0] if len(args) >= 1 else kwargs.get("input", kwargs.get("self"))
    weight = args[1] if len(args) >= 2 else kwargs.get("weight")
    if x is None or weight is None:
        raise ValueError("prelu expects (input, weight) as arguments.")

    if not (x.is_cuda and weight.is_cuda):
        raise AssertionError("Tensors must be CUDA tensors.")

    # The kernel assumes matching dtypes and flat contiguous layouts.
    if weight.dtype != x.dtype:
        weight = weight.to(dtype=x.dtype)
    x = x.contiguous()
    weight = weight.contiguous()

    out = torch.empty_like(x)
    n_elements = x.numel()
    if n_elements == 0:
        return out  # nothing to launch

    C, S, w_is_scalar = _prelu_layout(x, weight)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    prelu_kernel[grid](
        x, weight, out, n_elements, S, C, w_is_scalar=w_is_scalar, BLOCK_SIZE=BLOCK_SIZE
    )
    return out