Coverage for src/flag_gems/experimental_ops/hardswish.py: 0%
74 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def hardswish_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise hardswish: 0 for x <= -3, x for x >= 3, x*(x+3)/6 between.

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    input; math is done in float32 and cast back to the input's dtype.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds, other=0.0)
    v = vals.to(tl.float32)

    # Nested selects: identity in the upper saturation region, zero in the
    # lower one, and the quadratic ramp x*(x+3)/6 in between.
    ramp = (v * (v + 3.0)) / 6.0
    out32 = tl.where(v >= 3.0, v, tl.where(v <= -3.0, 0.0, ramp))

    tl.store(out_ptr + offs, out32.to(vals.dtype), mask=in_bounds)
def _launch_hardswish(x: torch.Tensor, out: torch.Tensor):
    """Launch ``hardswish_kernel`` over every element of ``x`` into ``out``.

    Both tensors are treated as flat buffers; a zero-element input is a
    no-op (launching an empty grid would be invalid).
    """
    total = x.numel()
    if total == 0:
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    hardswish_kernel[grid](x, out, total, BLOCK_SIZE=1024)
37def _parse_input_tensor(*args, **kwargs) -> torch.Tensor:
38 if len(args) >= 1 and isinstance(args[0], torch.Tensor):
39 return args[0]
40 if "self" in kwargs and isinstance(kwargs["self"], torch.Tensor):
41 return kwargs["self"]
42 if "input" in kwargs and isinstance(kwargs["input"], torch.Tensor):
43 return kwargs["input"]
44 raise ValueError(
45 "Expected input tensor as the first positional argument or as 'self'/'input' keyword argument."
46 )
49def _parse_out_tensor(*args, **kwargs) -> torch.Tensor:
50 if len(args) >= 2 and isinstance(args[1], torch.Tensor):
51 return args[1]
52 if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
53 return kwargs["out"]
54 raise ValueError(
55 "Expected 'out' tensor as the second positional argument or as 'out' keyword argument."
56 )
59def _ensure_cuda_tensor(t: torch.Tensor, name: str):
60 if not isinstance(t, torch.Tensor):
61 raise TypeError(f"{name} must be a torch.Tensor")
62 if not t.is_cuda:
63 raise ValueError(f"{name} must be a CUDA tensor (got device {t.device})")
# Floating-point dtypes the kernel is launched with directly; inputs of any
# other dtype are converted to float32 before the launch (see _hardswish_impl).
_supported_dtypes = {torch.float16, torch.bfloat16, torch.float32}
def _hardswish_impl(x: torch.Tensor, out: torch.Tensor = None):
    """Compute hardswish of ``x`` on CUDA, optionally into ``out``.

    Args:
        x: CUDA input tensor of any shape.
        out: optional CUDA tensor of the same shape to receive the result.

    Returns:
        ``out`` when supplied, otherwise a freshly allocated tensor that
        preserves ``x``'s layout/strides.

    Raises:
        TypeError / ValueError: via ``_ensure_cuda_tensor`` and the shape check.
    """
    _ensure_cuda_tensor(x, "input")
    if out is not None:
        _ensure_cuda_tensor(out, "out")
        if out.shape != x.shape:
            raise ValueError(f"out shape {out.shape} must match input shape {x.shape}")

    contiguous_x = x.contiguous()
    # The kernel only runs on the supported float dtypes; everything else is
    # staged through float32.
    if contiguous_x.dtype in _supported_dtypes:
        compute_dtype = contiguous_x.dtype
        work_x = contiguous_x
    else:
        compute_dtype = torch.float32
        work_x = contiguous_x.to(compute_dtype)

    # empty_like preserves the input's layout/strides for the allocated output.
    final_out = torch.empty_like(x) if out is None else out

    writable_in_place = (
        final_out.is_contiguous()
        and final_out.device == x.device
        and final_out.dtype in _supported_dtypes
        and final_out.dtype == compute_dtype
    )

    if writable_in_place:
        _launch_hardswish(work_x, final_out)
    else:
        # Stage through a contiguous buffer in the compute dtype, then copy
        # back (converting dtype/layout as needed).
        staging = torch.empty(work_x.shape, dtype=compute_dtype, device=work_x.device)
        _launch_hardswish(work_x, staging)
        final_out.copy_(staging.to(final_out.dtype))
    return final_out
def hardswish(*args, **kwargs):
    """Out-of-place hardswish: extract the input tensor from the call
    arguments and return a new tensor with the result."""
    input_tensor = _parse_input_tensor(*args, **kwargs)
    return _hardswish_impl(input_tensor)
def hardswish_out(*args, **kwargs):
    """Hardswish variant that writes the result into a caller-provided
    ``out`` tensor and returns it."""
    input_tensor = _parse_input_tensor(*args, **kwargs)
    out_tensor = _parse_out_tensor(*args, **kwargs)
    _hardswish_impl(input_tensor, out_tensor)
    return out_tensor