# Coverage report for src/flag_gems/experimental_ops/hardswish_.py: 0% of 42 statements covered.
# Generated by coverage.py v7.6.9 at 2026-03-09 01:57 +0800.
1import torch
2import triton
3import triton.language as tl
@triton.jit
def hardswish_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place hardswish over a flat buffer: x <- x * clamp(x + 3, 0, 6) / 6.

    Each program instance handles one BLOCK_SIZE-wide slice of the buffer;
    out-of-range lanes are masked on both the load and the store.
    """
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds)

    # clamp(val + 3, 0, 6) expressed via min/max, then scaled into [0, 1].
    gate = tl.minimum(tl.maximum(val + 3.0, 0.0), 6.0)

    tl.store(x_ptr + idx, val * (gate / 6.0), mask=in_bounds)
# Preserve a reference to the Triton kernel before defining the Python wrapper with the same name.
# The wrapper below shadows `hardswish_` at module level, so this alias is the
# only remaining handle on the JIT-compiled kernel; the wrapper launches it via
# `hardswish__kernel[grid](...)`.
hardswish__kernel = hardswish_
def hardswish_(*args, **kwargs):
    """Apply hardswish in place to a CUDA floating-point tensor.

    The tensor may be passed positionally or under the ``input``/``self``
    keywords. A Triton kernel rewrites the elements in place (via a
    contiguous scratch copy when the input is non-contiguous) and the
    original tensor object is returned.

    Raises:
        ValueError: if no tensor argument was given, or it is not on CUDA.
        TypeError: if the argument is not a Tensor, or not floating point.
    """
    # Resolve the tensor argument: positional first, then keyword aliases.
    tensor = args[0] if args else kwargs.get("input", kwargs.get("self"))

    # Validation guards — the order is part of the observable contract,
    # since each failure mode raises a distinct exception type/message.
    if tensor is None:
        raise ValueError("hardswish_: expected a Tensor as the first argument")
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("hardswish_: expected a Tensor")
    if not tensor.is_cuda:
        raise ValueError("hardswish_: expected a CUDA tensor")
    if not tensor.is_floating_point():
        raise TypeError("hardswish_: expected a floating point tensor")

    # The kernel assumes a dense layout; work on a contiguous copy if needed.
    work = tensor if tensor.is_contiguous() else tensor.contiguous()

    total = work.numel()
    if total == 0:
        # Nothing to launch for an empty tensor.
        return tensor

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(total, BLOCK_SIZE),)
    hardswish__kernel[grid](work, total, BLOCK_SIZE=BLOCK_SIZE)

    # If the kernel ran on a scratch copy, fold the results back into the
    # caller's tensor so the operation remains in-place from their view.
    if work.data_ptr() != tensor.data_ptr():
        tensor.copy_(work)

    return tensor