Coverage for src/flag_gems/experimental_ops/relu.py: 0%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def relu_kernel(
    input_ptr,  # Pointer to the input tensor's data
    output_ptr,  # Pointer to the output tensor's data
    n_elements,  # Total number of elements to process
    COMPUTE_FP32: tl.constexpr,  # Upcast to fp32 for the max, then cast back
    BLOCK_SIZE: tl.constexpr,  # Elements handled per program instance
):
    """Elementwise ReLU: out[i] = max(in[i], 0) for i < n_elements."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(input_ptr + idx, mask=in_bounds, other=0)

    if COMPUTE_FP32:
        # Low-precision inputs: compare in fp32, then cast back to the
        # loaded dtype so the stored result matches the input precision.
        result = tl.maximum(vals.to(tl.float32), 0.0).to(vals.dtype)
    else:
        result = tl.maximum(vals, 0)

    tl.store(output_ptr + idx, result, mask=in_bounds)

29 

30 

def relu(input: torch.Tensor) -> torch.Tensor:
    """Apply the rectified linear unit elementwise: ``max(input, 0)``.

    Args:
        input: A CUDA tensor. Floating-point, integer, and bool dtypes are
            supported; complex dtypes are rejected.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``input``.

    Raises:
        TypeError: If ``input`` is not a ``torch.Tensor`` or is complex.
        RuntimeError: If ``input`` is not on a CUDA device.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")

    if input.is_complex():
        raise TypeError("relu does not support complex tensors.")

    if not input.is_cuda:
        raise RuntimeError(
            "Triton kernels require CUDA tensors. Move the tensor to a CUDA device."
        )

    # Boolean tensors: ReLU is the identity, so no kernel launch is needed.
    if input.dtype == torch.bool:
        return input.clone()

    x = input.contiguous()
    out = torch.empty_like(x)

    n_elements = x.numel()
    # Zero-element tensors: return early rather than launching the kernel
    # with an empty grid of (0,).
    if n_elements == 0:
        return out

    # Only fp16/bf16 are upcast to fp32 inside the kernel; fp32/fp64 and
    # integer dtypes are handled in their native dtype.
    compute_in_fp32 = input.dtype in (torch.float16, torch.bfloat16)

    BLOCK_SIZE = 1024

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    relu_kernel[grid](
        x,
        out,
        n_elements,
        COMPUTE_FP32=compute_in_fp32,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out