Coverage for src/flag_gems/experimental_ops/relu_.py: 0%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import torch # noqa: F401 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def relu_(
    x_ptr,  # *Pointer* to input/output tensor (in-place).
    n_elements,  # Number of elements.
    BLOCK_SIZE: tl.constexpr,
):
    """Apply ReLU in place: each element becomes max(element, 0)."""
    # Each program instance processes one contiguous BLOCK_SIZE chunk.
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds)
    # Multiplying by zero yields a zero with the same dtype as vals,
    # so the select below never mixes dtypes.
    result = tl.where(vals > 0, vals, vals * 0)
    tl.store(x_ptr + offs, result, mask=in_bounds)

21 

22 

# Keep a reference to the Triton kernel before defining the Python wrapper with the same name.
# The wrapper below shadows `relu_`, so this alias is the only way to launch the kernel.
relu__kernel = relu_

25 

26 

def relu_(*args, **kwargs):
    """In-place ReLU on a CUDA tensor using the Triton kernel.

    The tensor is taken from the first positional argument, or from the
    ``input``/``x`` keyword when no positional arguments are given.

    Returns:
        The same tensor, modified in place.

    Raises:
        ValueError: if no tensor is supplied, or if it is not a
            contiguous CUDA tensor.
    """
    # Resolve the target tensor: positional wins, then "input", then "x".
    if args:
        tensor = args[0]
    else:
        tensor = kwargs.get("input", kwargs.get("x", None))
    if tensor is None:
        raise ValueError("relu_ expects a tensor as the first positional argument.")
    if not tensor.is_cuda:
        raise ValueError("relu_ Triton implementation requires a CUDA tensor.")
    if not tensor.is_contiguous():
        raise ValueError("relu_ Triton implementation requires a contiguous tensor.")

    numel = tensor.numel()
    # Empty tensor: nothing to do, skip the kernel launch entirely.
    if numel == 0:
        return tensor

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the flattened tensor.
        return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    relu__kernel[grid](tensor, numel, BLOCK_SIZE=1024)
    return tensor