Coverage for src/flag_gems/ops/hardswish_.py: 57%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def hardswish_kernel_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place hardswish over a flat buffer: x <- x * min(max(x + 3, 0), 6) / 6.

    Each program instance handles one BLOCK_SIZE-wide slice of the buffer;
    out-of-range lanes are masked off on both the load and the store.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    val = tl.load(x_ptr + idx, mask=in_bounds)

    # relu6(val + 3): clamp the shifted input into [0, 6].
    gate = tl.minimum(tl.maximum(val + 3.0, 0.0), 6.0)
    # Divide the gate first (matches the original evaluation order exactly).
    result = val * (gate / 6.0)

    tl.store(x_ptr + idx, result, mask=in_bounds)

32 

33 

def hardswish_(*args, **kwargs):
    """Apply hardswish in place: input <- input * relu6(input + 3) / 6.

    The tensor may be passed positionally or via the ``input`` / ``self``
    keyword (ATen-style). The same tensor object is returned, mutated
    in place.

    Raises:
        ValueError: if no tensor argument is supplied.
        TypeError: if the argument is not a floating-point torch.Tensor.
    """
    logger.debug("GEMS HARDSWISH_")

    # Resolve the tensor argument; a positional argument wins over keywords.
    if args:
        tensor = args[0]
    else:
        tensor = kwargs.get("input", kwargs.get("self", None))

    if tensor is None:
        raise ValueError("hardswish_: expected a Tensor as the first argument")
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("hardswish_: expected a Tensor")
    if not tensor.is_floating_point():
        raise TypeError("hardswish_: expected a floating point tensor")

    # The kernel walks a flat contiguous buffer; stage a copy when the
    # input is non-contiguous and write the result back afterwards.
    work = tensor if tensor.is_contiguous() else tensor.contiguous()

    total = work.numel()
    if total == 0:
        # Nothing to launch for an empty tensor.
        return tensor

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(work.device):
        hardswish_kernel_[grid](work, total, BLOCK_SIZE=BLOCK_SIZE)

    # Only copy back when the kernel ran on a staging buffer.
    if work.data_ptr() != tensor.data_ptr():
        tensor.copy_(work)

    return tensor