Coverage for src/flag_gems/experimental_ops/hardshrink.py: 0%
44 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def hardshrink_kernel(x_ptr, out_ptr, n_elements, lambd, BLOCK_SIZE: tl.constexpr):
    # Elementwise hardshrink over a flat buffer: values with |x| <= lambd
    # become 0.0, all other values pass through unchanged. One program
    # instance handles one BLOCK_SIZE-wide slice of the flattened tensor.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds, other=0.0)
    # Keep values strictly outside the [-lambd, lambd] band.
    outside_band = (vals > lambd) | (vals < -lambd)
    shrunk = tl.where(outside_band, vals, 0.0)
    tl.store(out_ptr + offs, shrunk, mask=in_bounds)
20def _hardshrink_launch(x: torch.Tensor, lambd: float, out: torch.Tensor):
21 assert x.is_cuda, "Input tensor must be on CUDA device"
22 assert out.is_cuda, "Output tensor must be on CUDA device"
23 assert (
24 x.numel() == out.numel()
25 ), "Input and output must have the same number of elements"
26 assert x.dtype == out.dtype, "Input and output must have the same dtype"
27 assert x.is_floating_point(), "hardshrink only supports floating point dtypes"
29 n_elements = x.numel()
30 BLOCK_SIZE = 1024
31 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
32 hardshrink_kernel[grid](x, out, n_elements, float(lambd), BLOCK_SIZE=BLOCK_SIZE)
def hardshrink(x: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
    """Return a new tensor with hardshrink applied elementwise to ``x``.

    A contiguous copy of the input is taken first, since the launch
    helper indexes tensors as flat buffers.
    """
    src = x.contiguous()
    result = torch.empty_like(src)
    _hardshrink_launch(src, lambd, result)
    return result
def hardshrink_out(
    x: torch.Tensor, lambd: float = 0.5, out: torch.Tensor = None
) -> torch.Tensor:
    """Apply hardshrink elementwise to ``x``, writing the result into ``out``.

    If ``out`` is None a fresh contiguous tensor is allocated. A
    caller-provided ``out`` must match ``x`` in device, dtype and shape;
    non-contiguous outputs are handled via a temporary buffer plus copy.

    Returns:
        The output tensor (``out`` when provided, else the new tensor).

    Raises:
        AssertionError: if ``out`` does not match ``x`` or is not on CUDA.
    """
    x_c = x.contiguous()
    if out is None:
        out = torch.empty_like(x_c)
        _hardshrink_launch(x_c, lambd, out)
        return out

    # Bug fix: the original launched and returned unconditionally above,
    # leaving all of the validation and non-contiguous handling below as
    # dead code — a caller-provided out was never checked, and a
    # non-contiguous out was written to as if it were flat-contiguous.
    assert out.is_cuda, "Output tensor must be on CUDA device"
    assert out.dtype == x_c.dtype, "Output dtype must match input dtype"
    assert out.shape == x_c.shape, "Output shape must match input shape"

    if out.is_contiguous():
        _hardshrink_launch(x_c, lambd, out)
    else:
        # Kernel requires flat-contiguous storage: compute into a
        # scratch tensor, then copy respecting out's strides.
        tmp = torch.empty_like(x_c)
        _hardshrink_launch(x_c, lambd, tmp)
        out.copy_(tmp)
    return out