Coverage for src/flag_gems/experimental_ops/hardswish.py: 0%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def hardswish_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise hardswish over a flat buffer.

    Piecewise definition: 0 for x <= -3, x for x >= 3, and x*(x+3)/6 in
    between. Math is done in float32 and the result is cast back to the
    input element type before the store.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    x32 = x.to(tl.float32)

    # Nested selects reproduce the three-branch piecewise definition:
    # outer select handles the saturated-high region, inner select picks
    # between the zero region and the quadratic middle segment.
    mid_val = (x32 * (x32 + 3.0)) / 6.0
    res32 = tl.where(x32 >= 3.0, x32, tl.where(x32 <= -3.0, 0.0, mid_val))

    tl.store(out_ptr + offsets, res32.to(x.dtype), mask=mask)

27 

28 

def _launch_hardswish(x: torch.Tensor, out: torch.Tensor):
    """Launch the hardswish Triton kernel over every element of `x`.

    `x` and `out` are treated as flat buffers of the same length; an
    empty input skips the launch entirely.
    """
    total = x.numel()
    if total == 0:
        # A zero-sized grid launch is pointless; bail out early.
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    hardswish_kernel[grid](x, out, total, BLOCK_SIZE=1024)

35 

36 

37def _parse_input_tensor(*args, **kwargs) -> torch.Tensor: 

38 if len(args) >= 1 and isinstance(args[0], torch.Tensor): 

39 return args[0] 

40 if "self" in kwargs and isinstance(kwargs["self"], torch.Tensor): 

41 return kwargs["self"] 

42 if "input" in kwargs and isinstance(kwargs["input"], torch.Tensor): 

43 return kwargs["input"] 

44 raise ValueError( 

45 "Expected input tensor as the first positional argument or as 'self'/'input' keyword argument." 

46 ) 

47 

48 

49def _parse_out_tensor(*args, **kwargs) -> torch.Tensor: 

50 if len(args) >= 2 and isinstance(args[1], torch.Tensor): 

51 return args[1] 

52 if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor): 

53 return kwargs["out"] 

54 raise ValueError( 

55 "Expected 'out' tensor as the second positional argument or as 'out' keyword argument." 

56 ) 

57 

58 

59def _ensure_cuda_tensor(t: torch.Tensor, name: str): 

60 if not isinstance(t, torch.Tensor): 

61 raise TypeError(f"{name} must be a torch.Tensor") 

62 if not t.is_cuda: 

63 raise ValueError(f"{name} must be a CUDA tensor (got device {t.device})") 

64 

65 

# Floating dtypes the kernel can store directly. Inputs with any other
# dtype are converted to float32 for the launch and the result is copied
# back to the destination dtype afterwards (see _hardswish_impl).
_supported_dtypes = {torch.float16, torch.bfloat16, torch.float32}

67 

68 

def _hardswish_impl(x: torch.Tensor, out: torch.Tensor = None):
    """Compute hardswish of `x` on CUDA via the Triton kernel.

    When `out` is None a fresh tensor matching `x` (including its
    layout, via empty_like) is allocated; otherwise the result is
    written into `out`, which must be a CUDA tensor of the same shape.
    Unsupported input dtypes are routed through a float32 compute buffer
    and cast back on copy-out.

    Raises:
        TypeError / ValueError: from _ensure_cuda_tensor for non-tensor
            or non-CUDA arguments.
        ValueError: if `out` has a different shape than `x`.
    """
    _ensure_cuda_tensor(x, "input")
    if out is not None:
        _ensure_cuda_tensor(out, "out")
        if out.shape != x.shape:
            raise ValueError(f"out shape {out.shape} must match input shape {x.shape}")

    # The kernel indexes a flat contiguous buffer, so force contiguity
    # and pick a dtype the kernel supports (falling back to float32).
    x_co = x.contiguous()
    compute_dtype = x_co.dtype if x_co.dtype in _supported_dtypes else torch.float32
    x_work = x_co if x_co.dtype == compute_dtype else x_co.to(compute_dtype)

    if out is None:
        final_out = torch.empty_like(x)  # preserve layout/strides of input
    else:
        final_out = out

    # The kernel may write straight into the destination only when it is
    # contiguous, on the same device, and already in the compute dtype;
    # otherwise results go through a scratch buffer and a copy_.
    can_write_direct = (
        final_out.is_contiguous()
        and final_out.device == x.device
        and final_out.dtype in _supported_dtypes
        and final_out.dtype == compute_dtype
    )

    if can_write_direct:
        out_work = final_out
        _launch_hardswish(x_work, out_work)
        return final_out
    else:
        # Scratch buffer in the compute dtype; copy_ handles the cast
        # and any layout difference of the real destination.
        out_work = torch.empty(x_work.shape, dtype=compute_dtype, device=x_work.device)
        _launch_hardswish(x_work, out_work)
        final_out.copy_(out_work.to(final_out.dtype))
        return final_out

101 

102 

def hardswish(*args, **kwargs):
    """Functional hardswish: parse the input tensor from the call
    arguments and return a freshly allocated result tensor."""
    return _hardswish_impl(_parse_input_tensor(*args, **kwargs))

106 

107 

def hardswish_out(*args, **kwargs):
    """Out-variant hardswish: write the result into the caller-supplied
    'out' tensor and return that same tensor."""
    tensor_in = _parse_input_tensor(*args, **kwargs)
    tensor_out = _parse_out_tensor(*args, **kwargs)
    _hardswish_impl(tensor_in, tensor_out)
    return tensor_out