# Coverage for src/flag_gems/experimental_ops/relu.py: 0% (39 statements)
# Report generated by coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def relu_kernel(
    input_ptr,  # Pointer to input tensor
    output_ptr,  # Pointer to output tensor
    n_elements,  # Number of elements
    COMPUTE_FP32: tl.constexpr,  # Whether to upcast to fp32 for computation
    BLOCK_SIZE: tl.constexpr,  # Elements handled by one program instance
):
    """Element-wise ReLU over a flat buffer: out[i] = max(in[i], 0).

    Each program instance processes one contiguous block of BLOCK_SIZE
    elements; the launch grid is expected to cover
    ceil(n_elements / BLOCK_SIZE) programs. COMPUTE_FP32 is a
    compile-time constant, so the `if` below selects a specialized
    kernel variant rather than branching at runtime.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask guards the tail block: out-of-range lanes load `other=0`
    # and are skipped on store.
    mask = offsets < n_elements
    x = tl.load(input_ptr + offsets, mask=mask, other=0)
    if COMPUTE_FP32:
        # Upcast (e.g. fp16/bf16 inputs) to fp32 for the max, then cast
        # back to the source dtype before storing.
        x_f32 = x.to(tl.float32)
        y_f32 = tl.maximum(x_f32, 0.0)
        y = y_f32.to(x.dtype)
    else:
        # Native-dtype path (fp32/fp64/integers).
        y = tl.maximum(x, 0)
    tl.store(output_ptr + offsets, y, mask=mask)
def relu(input: torch.Tensor) -> torch.Tensor:
    """Element-wise ReLU (max(x, 0)) computed with a Triton kernel.

    Args:
        input: CUDA tensor of any shape. Bool, integer, and floating
            dtypes are supported; complex dtypes are rejected.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``input``.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor`` or is complex.
        RuntimeError: if ``input`` is not on a CUDA device.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("input must be a torch.Tensor")

    if input.is_complex():
        raise TypeError("relu does not support complex tensors.")

    if not input.is_cuda:
        raise RuntimeError(
            "Triton kernels require CUDA tensors. Move the tensor to a CUDA device."
        )

    # Boolean tensors: ReLU is the identity (no negative values exist).
    if input.dtype == torch.bool:
        return input.clone()

    # Upcast half-precision floats to fp32 inside the kernel for accuracy;
    # fp32/fp64 and integer dtypes are computed in their native dtype.
    compute_in_fp32 = input.dtype in (torch.float16, torch.bfloat16)

    x = input.contiguous()
    out = torch.empty_like(x)

    n_elements = x.numel()
    # Nothing to compute for empty tensors; also avoids launching a
    # kernel with an empty (zero-block) grid.
    if n_elements == 0:
        return out

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    relu_kernel[grid](
        x,
        out,
        n_elements,
        COMPUTE_FP32=compute_in_fp32,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out