Coverage for src/flag_gems/experimental_ops/hardsigmoid.py: 0%
29 statements
coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def hardsigmoid_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise hardsigmoid: ``out = clamp(x / 6 + 0.5, 0, 1)``.

    Each program instance handles one chunk of ``BLOCK_SIZE`` consecutive
    elements of a flat buffer holding ``n_elements`` values; lanes past the
    end of the buffer are masked off for both load and store.

    Args:
        x_ptr: Pointer to the input values, indexed as a flat 1-D buffer.
        out_ptr: Pointer to the output values, indexed as a flat 1-D buffer.
        n_elements: Number of elements to process.
        BLOCK_SIZE: Elements handled per program instance (compile-time).
    """
    pid = tl.program_id(axis=0)
    # Widen before multiplying: an int32 product overflows once the tensor
    # has more than 2**31 elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # Arithmetic is done in float32 regardless of the input dtype.
    # NOTE(review): float64 inputs are therefore also computed in float32 —
    # confirm this precision loss is acceptable.
    xf = x.to(tl.float32)
    y = xf * (1.0 / 6.0) + 0.5
    y = tl.minimum(tl.maximum(y, 0.0), 1.0)
    # tl.maximum/tl.minimum do not guarantee NaN propagation by default,
    # while torch's hardsigmoid maps NaN -> NaN; pass NaNs through explicitly.
    y = tl.where(xf != xf, xf, y)
    tl.store(out_ptr + offsets, y.to(x.dtype), mask=mask)
def hardsigmoid(x: torch.Tensor):
    """Return ``hardsigmoid(x)`` as a new tensor, computed by a Triton kernel.

    Args:
        x: CUDA input tensor. It is made contiguous before launch because the
            kernel indexes memory as a flat array; a strided view such as
            ``x[:, ::2]`` would otherwise be read incorrectly.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: If ``x`` is not a CUDA tensor.
    """
    # Validate before allocating anything.
    assert x.is_cuda
    x = x.contiguous()
    out = torch.empty_like(x)
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip the kernel launch.
        return out

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    hardsigmoid_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out
def hardsigmoid_out(x: torch.Tensor, out: torch.Tensor):
    """Compute ``hardsigmoid(x)`` into the caller-provided tensor ``out``.

    Elements are matched in row-major logical order, so ``out`` may have a
    different shape than ``x`` as long as the element counts agree.

    Args:
        x: CUDA input tensor (made contiguous before launch, since the kernel
            indexes memory as a flat array).
        out: CUDA destination tensor with ``x.numel()`` elements. It may be
            non-contiguous; in that case results are computed into a
            contiguous scratch buffer and then copied into ``out``.
            NOTE(review): if ``out.dtype`` differs from ``x.dtype``, this
            relies on ``tl.store`` converting to the pointer's element type —
            confirm that is intended.

    Returns:
        ``out``.

    Raises:
        AssertionError: If either tensor is not on CUDA or the element counts
            differ.
    """
    assert x.is_cuda and out.is_cuda
    assert x.numel() == out.numel()
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip the kernel launch.
        return out

    x = x.contiguous()
    # The kernel writes a flat buffer; a strided `out` needs a scratch
    # buffer followed by copy_, otherwise elements land in the wrong places.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty(out.shape, dtype=out.dtype, device=out.device)

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    hardsigmoid_kernel[grid](x, target, n_elements, BLOCK_SIZE=1024)
    if target is not out:
        out.copy_(target)
    return out