Coverage for src/flag_gems/experimental_ops/prelu.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def prelu(
    x_ptr,  # *Pointer* to input tensor.
    w_ptr,  # *Pointer* to weight tensor (scalar or per-channel vector).
    out_ptr,  # *Pointer* to output tensor.
    n_elements,  # Total number of elements in input.
    S,  # Spatial size = product of dims after channel dim (or 1 if none).
    C,  # Number of channels (or 1).
    w_is_scalar: tl.constexpr,  # Whether weight is a single scalar.
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise PReLU over a flat view of the input:
    #   out[i] = x[i]            if x[i] >= 0
    #   out[i] = alpha * x[i]    otherwise
    # where alpha is either the lone weight value or the weight for the
    # channel that flat index i belongs to, recovered as (i // S) % C.
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds)

    # w_is_scalar is a constexpr, so exactly one branch is compiled in.
    if w_is_scalar:
        slope = tl.load(w_ptr)
    else:
        chan = (offs // S) % C
        slope = tl.load(w_ptr + chan, mask=in_bounds)

    result = tl.where(vals >= 0, vals, slope * vals)
    tl.store(out_ptr + offs, result, mask=in_bounds)

33 

34 

# Keep a reference to the Triton kernel under a distinct name, because the
# Python wrapper defined below reuses (and thus shadows) the name `prelu`.
prelu_kernel = prelu

37 

38 

def prelu(*args, **kwargs):
    """PReLU activation computed by the Triton kernel ``prelu_kernel``.

    Computes ``out = x`` where ``x >= 0`` and ``out = weight * x`` elsewhere.
    ``weight`` is either a single scalar or a per-channel vector whose length
    matches the channel dimension of ``x`` (dim 0 for 1-D inputs, dim 1
    otherwise).

    Accepts ``(input, weight)`` positionally, via the ``input``/``self`` and
    ``weight`` keywords, or mixed (``prelu(x, weight=w)``).

    Returns:
        A new tensor with the same shape and dtype as the (contiguous) input.

    Raises:
        ValueError: if input/weight cannot be determined from the arguments.
        AssertionError: if the tensors are not CUDA tensors, if a non-scalar
            weight is paired with a 0-dim input, or if the weight size does
            not match the channel dimension.
    """
    # Extract inputs. Fix: the original only handled len(args) >= 2 or
    # pure-keyword calls, so the mixed form prelu(x, weight=w) wrongly
    # raised ValueError; the len(args) == 1 branch covers it.
    if len(args) >= 2:
        x, weight = args[0], args[1]
    elif len(args) == 1:
        x = args[0]
        weight = kwargs.get("weight")
    else:
        x = kwargs.get("input", kwargs.get("self"))
        weight = kwargs.get("weight")
    if x is None or weight is None:
        raise ValueError("prelu expects (input, weight) as arguments.")

    if not (x.is_cuda and weight.is_cuda):
        raise AssertionError("Tensors must be CUDA tensors.")

    # The kernel loads x and weight with one element type; match them up front.
    if weight.dtype != x.dtype:
        weight = weight.to(dtype=x.dtype)

    # Flat pointer arithmetic in the kernel assumes contiguous layouts.
    x = x.contiguous()
    weight = weight.contiguous()

    out = torch.empty_like(x)

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to launch for empty tensors.
        return out

    # Determine channel count C and spatial size S (product of dims after the
    # channel dim) so the kernel can recover a flat index's channel as
    # (i // S) % C.
    ndim = x.dim()
    if weight.numel() == 1:
        C = 1
        S = 1
        w_is_scalar = True
    else:
        if ndim == 0:
            raise AssertionError("Non-scalar weight provided for a 0-dim input.")
        if ndim == 1:
            C = x.shape[0]
            S = 1
        else:
            C = x.shape[1]
            S = 1
            for d in x.shape[2:]:
                S *= d
        if weight.numel() != C:
            raise AssertionError(
                f"Weight numel ({weight.numel()}) must equal channel dimension size ({C})."
            )
        w_is_scalar = False

    # Make sure S and C are at least 1 to avoid div/mod by zero in kernel math.
    C = max(int(C), 1)
    S = max(int(S), 1)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    prelu_kernel[grid](
        x, weight, out, n_elements, S, C, w_is_scalar=w_is_scalar, BLOCK_SIZE=BLOCK_SIZE
    )
    return out