Coverage for src/flag_gems/experimental_ops/hardtanh_.py: 0%
51 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def hardtanh_(
    x_ptr,  # *Pointer* to input/output tensor (updated in place).
    n_elements,  # Total number of elements in the tensor.
    min_val,  # Lower clamp bound (scalar).
    max_val,  # Upper clamp bound (scalar).
    BLOCK_SIZE: tl.constexpr,  # Elements processed per program instance.
):
    """Clamp every element of the tensor to [min_val, max_val], in place."""
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)

    # Materialize the bounds in the tensor's dtype so the min/max ops do not
    # implicitly upcast (e.g. fp16 data vs. fp32 scalars).
    min_v = tl.full([1], min_val, dtype=x.dtype)
    max_v = tl.full([1], max_val, dtype=x.dtype)

    # Apply the lower bound first, then the upper bound. With this order,
    # when min_val > max_val all elements collapse to max_val — matching
    # torch.clamp semantics, which the wrapper's fallback paths rely on.
    # (The previous order returned min_val in that edge case.)
    x = tl.maximum(x, min_v)
    x = tl.minimum(x, max_v)

    tl.store(x_ptr + offsets, x, mask=mask)
# Keep a reference to the Triton kernel before defining the Python wrapper of the same name
# (the wrapper below shadows the module-level name ``hardtanh_``; launches go
# through this alias instead).
hardtanh___kernel = hardtanh_
def hardtanh_(*args, **kwargs):
    """In-place hardtanh: clamp a tensor's elements to [min_val, max_val].

    Accepted signature: ``hardtanh_(x, min_val=-1.0, max_val=1.0)``. The
    tensor may be given positionally or via the ``input``/``self`` keyword.
    Contiguous floating-point CUDA tensors go through the Triton kernel;
    everything else falls back to ``torch.clamp_``. Returns ``x``.
    """
    if not args and "input" not in kwargs and "self" not in kwargs:
        raise TypeError("hardtanh_ expected at least 1 argument: a CUDA tensor")

    # Locate the tensor: positional first, then the common keyword spellings.
    if args:
        x = args[0]
    elif "input" in kwargs:
        x = kwargs["input"]
    else:
        x = kwargs["self"]

    # Clamp bounds: start from defaults, let positionals override them, and
    # let explicit (non-None) keywords override positionals.
    min_val, max_val = -1.0, 1.0
    if len(args) > 1:
        min_val = args[1]
    if len(args) > 2:
        max_val = args[2]
    if kwargs.get("min_val") is not None:
        min_val = kwargs["min_val"]
    if kwargs.get("max_val") is not None:
        max_val = kwargs["max_val"]

    if not isinstance(x, torch.Tensor):
        raise TypeError("hardtanh_ expects a torch.Tensor as the first argument")

    # The Triton path requires a contiguous floating-point CUDA tensor; any
    # other case (CPU, integer dtype, non-contiguous layout) is delegated to
    # PyTorch's in-place clamp, which preserves the expected semantics.
    if not (x.is_cuda and x.is_floating_point() and x.is_contiguous()):
        return torch.clamp_(x, min=min_val, max=max_val)

    n_elements = x.numel()
    if n_elements == 0:
        return x  # nothing to launch for an empty tensor

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch the Triton kernel; it writes through x's storage in place.
    hardtanh___kernel[grid](
        x, n_elements, float(min_val), float(max_val), BLOCK_SIZE=BLOCK_SIZE
    )
    return x