Coverage for src/flag_gems/experimental_ops/leaky_relu.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _leaky_relu_kernel(
    x_ptr, y_ptr, n_elements, negative_slope, BLOCK_SIZE: tl.constexpr
):
    """Elementwise LeakyReLU over a flat buffer.

    Each program instance handles one BLOCK_SIZE-wide slice of the input:
    y = x for x >= 0, otherwise y = negative_slope * x.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds, other=0)
    zeros = tl.zeros([BLOCK_SIZE], dtype=vals.dtype)
    alpha = tl.full([BLOCK_SIZE], negative_slope, dtype=vals.dtype)
    # Branch-free form: the positive part passes through unchanged while
    # the negative part is scaled by the slope.
    result = tl.maximum(vals, zeros) + alpha * tl.minimum(vals, zeros)

    tl.store(y_ptr + offs, result, mask=in_bounds)

23 

24 

25def _launch_leaky_relu_kernel( 

26 x: torch.Tensor, out: torch.Tensor, negative_slope: float 

27): 

28 if not x.is_cuda or not out.is_cuda: 

29 raise ValueError("Input and output tensors must be on CUDA device.") 

30 if x.numel() != out.numel(): 

31 raise ValueError("Input and output must have the same number of elements.") 

32 if x.dtype != out.dtype: 

33 raise ValueError("Input and output tensors must have the same dtype.") 

34 if not x.is_contiguous(): 

35 x = x.contiguous() 

36 if not out.is_contiguous(): 

37 raise ValueError("Output tensor must be contiguous.") 

38 

39 n_elements = x.numel() 

40 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

41 _leaky_relu_kernel[grid](x, out, n_elements, float(negative_slope), BLOCK_SIZE=1024) 

42 return out 

43 

44 

def leaky_relu(input: torch.Tensor, negative_slope: float = 0.01) -> torch.Tensor:
    """
    ATen: ('leaky_relu', <Autograd.disable: False>)

    Out-of-place LeakyReLU: returns a new tensor where negative elements of
    ``input`` are scaled by ``negative_slope`` and the rest pass through.
    """
    # Force a contiguous output buffer: plain empty_like() preserves the
    # input's memory format, and a non-contiguous ``out`` would be rejected
    # by the launcher when ``input`` is non-contiguous.
    out = torch.empty_like(input, memory_format=torch.contiguous_format)
    return _launch_leaky_relu_kernel(input, out, negative_slope)

51 

52 

def leaky_relu_out(
    input: torch.Tensor, negative_slope: float = 0.01, out: torch.Tensor = None
):
    """
    ATen: ('leaky_relu.out', <Autograd.disable: False>)

    LeakyReLU writing into a caller-supplied ``out`` tensor, which is
    mandatory despite the keyword default.
    """
    if out is not None:
        return _launch_leaky_relu_kernel(input, out, negative_slope)
    raise ValueError("Argument 'out' must be provided for leaky_relu_out.")