# Coverage report residue (coverage.py v7.6.9, created 2026-03-13 10:08 +0800):
# src/flag_gems/experimental_ops/hardtanh.py — 0% of 49 statements covered.
1import torch
2import triton
3import triton.language as tl
@triton.jit
def hardtanh_kernel(
    x_ptr, out_ptr, n_elements, min_val, max_val, BLOCK_SIZE: tl.constexpr
):
    """Elementwise hardtanh: out[i] = clamp(x[i], min_val, max_val).

    One program instance processes one BLOCK_SIZE-wide slice of the flat
    element range; the tail block is masked off.

    Args:
        x_ptr: pointer to the (contiguous) input buffer.
        out_ptr: pointer to the (contiguous) output buffer, same dtype as input.
        n_elements: total number of elements to process.
        min_val: lower clamp bound (expected to be a float32 scalar).
        max_val: upper clamp bound (expected to be a float32 scalar).
        BLOCK_SIZE: compile-time block width.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the final partial block.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Clamp in float32 so half-precision inputs compare against the float
    # bounds without precision surprises, then cast back to the input dtype.
    x_fp32 = x.to(tl.float32)
    min_v = min_val  # expected to be float32 scalar
    max_v = max_val  # expected to be float32 scalar
    x_clamped = tl.maximum(tl.minimum(x_fp32, max_v), min_v)
    y = x_clamped.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)
def _launch_hardtanh(
    input: torch.Tensor, output: torch.Tensor, min_val: float, max_val: float
):
    """Validate the tensor pair and launch the hardtanh kernel over it.

    Both tensors are expected to be contiguous CUDA tensors of identical
    dtype on the same device. Returns ``output`` for chaining.
    """
    assert input.is_cuda and output.is_cuda, "Tensors must be on CUDA device"
    assert input.device == output.device, "Input and output must be on the same device"
    assert input.dtype == output.dtype, "Input and output must have the same dtype"

    total = input.numel()
    # Nothing to launch for an empty tensor.
    if total == 0:
        return output

    block_width = 1024

    def grid(meta):
        # One program per BLOCK_SIZE-wide slice of the flat range.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    hardtanh_kernel[grid](
        input, output, total, float(min_val), float(max_val), BLOCK_SIZE=block_width
    )
    return output
def hardtanh(self: torch.Tensor, min_val: float = -1.0, max_val: float = 1.0):
    """Return a new tensor with elements of ``self`` clamped to [min_val, max_val].

    The input may be non-contiguous; computation runs on a contiguous copy
    and the result is reshaped to match the input's shape.
    """
    assert self.is_cuda, "Input tensor must be on CUDA device"
    src = self.contiguous()
    result = torch.empty_like(src)
    _launch_hardtanh(src, result, min_val, max_val)
    # The result is always a fresh contiguous tensor matching the input's
    # shape and dtype, regardless of the input's original layout.
    return result.view_as(self)
def hardtanh_out(
    self: torch.Tensor,
    min_val: float = -1.0,
    max_val: float = 1.0,
    out: "torch.Tensor | None" = None,
):
    """Clamp ``self`` to [min_val, max_val], writing the result into ``out``.

    Args:
        self: CUDA input tensor (may be non-contiguous).
        min_val: lower clamp bound.
        max_val: upper clamp bound.
        out: optional destination tensor; must match ``self`` in shape and
            dtype. Allocated fresh when ``None``.

    Returns:
        ``out``, containing the clamped values.
    """
    x = self
    assert x.is_cuda, "Input tensor must be on CUDA device"
    if out is None:
        out = torch.empty_like(x)
    assert out.is_cuda, "Output tensor must be on CUDA device"
    assert out.shape == x.shape, "Output tensor must have the same shape as input"
    assert out.dtype == x.dtype, "Output tensor must have the same dtype as input"
    x_contig = x.contiguous()
    if out.is_contiguous():
        _launch_hardtanh(x_contig, out, min_val, max_val)
        return out
    # Non-contiguous out: compute into a fresh contiguous buffer, then copy
    # back through out's strides. Allocate the buffer directly — the previous
    # `torch.empty_like(out.contiguous())` materialized a full copy of out's
    # current (about-to-be-overwritten) contents just to get a layout template.
    out_contig = torch.empty(out.shape, dtype=out.dtype, device=out.device)
    _launch_hardtanh(x_contig, out_contig, min_val, max_val)
    out.copy_(out_contig)
    return out