Coverage for src/flag_gems/experimental_ops/hardswish_.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def hardswish_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place hardswish over a flat buffer: x <- x * clamp(x + 3, 0, 6) / 6.

    Each program instance handles one BLOCK_SIZE-wide slice of the buffer;
    out-of-range lanes are masked off so any n_elements is safe.
    """
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds)

    # relu6(x + 3): shift by 3, then clamp into [0, 6].
    gate = tl.minimum(tl.maximum(val + 3.0, 0.0), 6.0)
    result = val * (gate / 6.0)

    # Write the result back over the input (in-place op).
    tl.store(x_ptr + idx, result, mask=in_bounds)

25 

26 

# Preserve a reference to the Triton kernel before defining the Python wrapper
# with the same name below; the wrapper launches the kernel via this alias.
hardswish__kernel = hardswish_

29 

30 

def hardswish_(*args, **kwargs):
    """Apply hardswish in place: x <- x * clamp(x + 3, 0, 6) / 6.

    The target tensor is taken either positionally or via the ``input``
    (or ``self``) keyword, mirroring torch's calling convention. The same
    tensor object that was passed in is returned.

    Raises:
        ValueError: no tensor given, or the tensor is not on CUDA.
        TypeError: the argument is not a Tensor, or is not floating point.
    """
    # Resolve the tensor operand from positional or keyword arguments.
    tensor = args[0] if args else kwargs.get("input", kwargs.get("self", None))

    if tensor is None:
        raise ValueError("hardswish_: expected a Tensor as the first argument")
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("hardswish_: expected a Tensor")
    if not tensor.is_cuda:
        raise ValueError("hardswish_: expected a CUDA tensor")
    if not tensor.is_floating_point():
        raise TypeError("hardswish_: expected a floating point tensor")

    # The kernel indexes a flat contiguous buffer; copy only when necessary.
    work = tensor if tensor.is_contiguous() else tensor.contiguous()

    total = work.numel()
    if total == 0:
        return tensor

    BLOCK = 1024
    launch_grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    hardswish__kernel[launch_grid](work, total, BLOCK_SIZE=BLOCK)

    # If a contiguous copy was made, write the results back into the original.
    if work.data_ptr() != tensor.data_ptr():
        tensor.copy_(work)

    return tensor