Coverage for src/flag_gems/ops/prelu.py: 65%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

@triton.jit
def prelu(
    x_ptr,  # *Pointer* to input tensor.
    w_ptr,  # *Pointer* to weight tensor (scalar or per-channel vector).
    out_ptr,  # *Pointer* to output tensor.
    n_elements,  # Total number of elements in input.
    S,  # Spatial size = product of dims after channel dim (or 1 if none).
    C,  # Number of channels (or 1).
    w_is_scalar: tl.constexpr,  # Whether weight is a single scalar.
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise PReLU: out = x where x >= 0, alpha * x otherwise.

    alpha is a single shared scalar when ``w_is_scalar`` is set; otherwise it
    is looked up per element as ``weight[(flat_offset // S) % C]``, which
    recovers the channel index of a contiguous channels-second layout.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    x = tl.load(x_ptr + idx, mask=in_bounds)

    if w_is_scalar:
        # Single alpha shared by every element; no mask needed for one scalar.
        alpha = tl.load(w_ptr)
    else:
        # Map each flat offset back to its channel for the per-channel lookup.
        channel = (idx // S) % C
        alpha = tl.load(w_ptr + channel, mask=in_bounds)

    result = tl.where(x >= 0, x, alpha * x)
    tl.store(out_ptr + idx, result, mask=in_bounds)

38 

39 

# Keep a reference to the Triton kernel before defining the Python wrapper
# with the same name; the wrapper launches it via prelu_kernel[grid](...).
prelu_kernel = prelu

42 

43 

def prelu(*args, **kwargs):
    """Apply PReLU elementwise on CUDA tensors: out = x if x >= 0 else weight * x.

    Accepts ``(input, weight)`` positionally, fully by keyword
    (``input``/``self`` and ``weight``), or mixed (``prelu(x, weight=w)``).
    ``weight`` is either a single scalar or a per-channel vector whose length
    matches the channel dimension (dim 1 for ndim >= 2, dim 0 for 1-D input).

    Returns:
        A new tensor with the same shape and dtype as ``input``.

    Raises:
        ValueError: if ``input`` or ``weight`` cannot be extracted.
        AssertionError: if the tensors are not CUDA tensors, if a non-scalar
            weight is given for a 0-dim input, or if ``weight.numel()`` does
            not match the channel count.
    """
    logger.debug("GEMS PRELU")

    # Extract inputs. Also handle the mixed call prelu(x, weight=w), which
    # previously fell through to the keyword-only branch and raised ValueError.
    if len(args) >= 2:
        x, weight = args[0], args[1]
    else:
        x = args[0] if args else kwargs.get("input", kwargs.get("self"))
        weight = kwargs.get("weight")
    if x is None or weight is None:
        raise ValueError("prelu expects (input, weight) as arguments.")

    if not (x.is_cuda and weight.is_cuda):
        raise AssertionError("Tensors must be CUDA tensors.")

    # The kernel assumes matching dtypes; promote weight to the input's dtype.
    if weight.dtype != x.dtype:
        weight = weight.to(dtype=x.dtype)

    # The kernel's flat-offset -> channel math requires contiguous layouts.
    x = x.contiguous()
    weight = weight.contiguous()

    out = torch.empty_like(x)

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip the (empty) kernel launch.
        return out

    # Determine channel count C and spatial size S (product of dims after the
    # channel dim); the kernel uses them to map flat offsets to channels.
    ndim = x.dim()
    if weight.numel() == 1:
        C = 1
        S = 1
        w_is_scalar = True
    else:
        if ndim == 0:
            raise AssertionError("Non-scalar weight provided for a 0-dim input.")
        if ndim == 1:
            # NOTE(review): treats a 1-D input's only dim as the channel dim;
            # torch.prelu requires a scalar weight for ndim < 2 — confirm this
            # divergence is intended before relying on it.
            C = x.shape[0]
            S = 1
        else:
            C = x.shape[1]
            S = 1
            for d in x.shape[2:]:
                S *= d
        if weight.numel() != C:
            raise AssertionError(
                f"Weight numel ({weight.numel()}) must equal channel dimension size ({C})."
            )
        w_is_scalar = False

    # Make sure S and C are at least 1 to avoid div/mod by zero in kernel math.
    C = max(int(C), 1)
    S = max(int(S), 1)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    prelu_kernel[grid](
        x, weight, out, n_elements, S, C, w_is_scalar=w_is_scalar, BLOCK_SIZE=BLOCK_SIZE
    )
    return out