Coverage for src/flag_gems/experimental_ops/hardsigmoid_.py: 0%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _hardsigmoid_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Apply hardsigmoid in place over a flat buffer: x <- clamp(x + 3, 0, 6) / 6.

    Each program handles one contiguous tile of ``BLOCK_SIZE`` elements.
    Lanes at or beyond ``n_elements`` are masked off for both the load and
    the store.

    Args:
        x_ptr: pointer to the first element of the buffer; overwritten in place.
        n_elements: total number of elements in the buffer.
        BLOCK_SIZE: elements processed per program (compile-time constant).
    """
    pid = tl.program_id(axis=0)
    # Widen before multiplying: in int32, ``pid * BLOCK_SIZE`` wraps once the
    # buffer holds more than 2**31 - 1 elements, producing invalid addresses.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # NOTE(review): whether fp16/bf16 math is promoted to fp32 here depends on
    # Triton's scalar-promotion rules for the installed version — confirm the
    # accuracy target against PyTorch's reference.
    y = tl.minimum(tl.maximum(x + 3.0, 0.0), 6.0) / 6.0
    # Store back in the buffer's element type explicitly.
    tl.store(x_ptr + offsets, y.to(x_ptr.dtype.element_ty), mask=mask)


# Private handle on the Triton kernel; the Python-level ``hardsigmoid_``
# wrapper defined next launches the kernel through this name.
_hardsigmoid_triton = _hardsigmoid_kernel
def hardsigmoid_(*args, **kwargs):
    """In-place hardsigmoid backed by a Triton kernel.

    Computes ``x = clamp(x + 3, 0, 6) / 6`` element-wise on the input tensor
    and returns that same tensor object.

    The tensor is taken from the first positional argument. Otherwise it is
    taken from the ``input`` keyword, falling back to ``self``.

    Raises:
        ValueError: no tensor argument was supplied (or it was ``None``).
        TypeError: the argument is not a ``torch.Tensor``, or is not floating point.
        RuntimeError: the tensor does not live on a CUDA device.
    """
    if args:
        x = args[0]
    else:
        x = kwargs.get("input", kwargs.get("self", None))

    if x is None:
        raise ValueError("hardsigmoid_ expects a tensor as the first argument.")
    if not isinstance(x, torch.Tensor):
        raise TypeError("hardsigmoid_ expects a torch.Tensor as input.")
    if not x.is_floating_point():
        raise TypeError("hardsigmoid_ only supports floating point tensors.")
    if x.device.type != "cuda":
        raise RuntimeError("hardsigmoid_ Triton kernel requires a CUDA tensor.")

    BLOCK_SIZE = 1024

    def launch(t: torch.Tensor):
        # The kernel indexes a flat buffer, so ``t`` must be contiguous.
        n_elements = t.numel()
        if n_elements == 0:
            return
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _hardsigmoid_triton[grid](t, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    # Triton launches on the *current* CUDA device. Make the tensor's own
    # device current so tensors on a non-default GPU are handled correctly.
    with torch.cuda.device(x.device):
        if x.is_contiguous():
            launch(x)
        else:
            # Run on a contiguous copy, then write the result back into the
            # caller's (strided) tensor to keep in-place semantics.
            tmp = x.contiguous()
            launch(tmp)
            x.copy_(tmp)

    return x