Coverage for src/flag_gems/experimental_ops/hardshrink.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

from typing import Optional

import torch
import triton
import triton.language as tl

4 

5 

@triton.jit
def hardshrink_kernel(x_ptr, out_ptr, n_elements, lambd, BLOCK_SIZE: tl.constexpr):
    """Elementwise hardshrink: y = x where |x| > lambd, otherwise 0.

    One program instance handles one BLOCK_SIZE-wide contiguous slice of
    the flattened input; `mask` guards the ragged final block.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds, other=0.0)
    # Keep values strictly outside the [-lambd, lambd] band; zero the rest.
    outside_band = (vals > lambd) | (vals < -lambd)
    shrunk = tl.where(outside_band, vals, 0.0)
    tl.store(out_ptr + offs, shrunk, mask=in_bounds)

18 

19 

def _hardshrink_launch(x: torch.Tensor, lambd: float, out: torch.Tensor):
    """Validate arguments and launch ``hardshrink_kernel`` writing into ``out``.

    The kernel addresses both tensors with flat offsets, so callers must pass
    contiguous tensors (both public wrappers in this file arrange that).

    Args:
        x: CUDA floating-point input tensor.
        lambd: shrinkage threshold; values in [-lambd, lambd] become 0.
        out: CUDA tensor receiving the result; same dtype and element count as x.
    """
    assert x.is_cuda, "Input tensor must be on CUDA device"
    assert out.is_cuda, "Output tensor must be on CUDA device"
    assert (
        x.numel() == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"
    assert x.is_floating_point(), "hardshrink only supports floating point dtypes"

    n_elements = x.numel()
    if n_elements == 0:
        # Empty tensor: nothing to compute; skip launching a zero-block grid.
        return
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    hardshrink_kernel[grid](x, out, n_elements, float(lambd), BLOCK_SIZE=BLOCK_SIZE)

33 

34 

def hardshrink(x: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
    """Return the hardshrink of ``x``: 0 where |x| <= lambd, x elsewhere.

    Always allocates and returns a fresh contiguous tensor.
    """
    contiguous_input = x.contiguous()
    result = torch.empty_like(contiguous_input)
    _hardshrink_launch(contiguous_input, lambd, result)
    return result

40 

41 

def hardshrink_out(
    x: torch.Tensor, lambd: float = 0.5, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Hardshrink with an optional caller-provided output tensor.

    Args:
        x: CUDA floating-point input tensor.
        lambd: shrinkage threshold; values in [-lambd, lambd] become 0.
        out: optional destination tensor; must be on CUDA and match the
            (contiguous) input's dtype and shape. Allocated when None.

    Returns:
        The tensor holding the result (``out`` when provided).
    """
    x_c = x.contiguous()
    if out is None:
        out = torch.empty_like(x_c)
        _hardshrink_launch(x_c, lambd, out)
        return out

    # Validate a caller-provided output tensor before writing into it.
    assert out.is_cuda, "Output tensor must be on CUDA device"
    assert out.dtype == x_c.dtype, "Output dtype must match input dtype"
    assert out.shape == x_c.shape, "Output shape must match input shape"

    if out.is_contiguous():
        _hardshrink_launch(x_c, lambd, out)
    else:
        # The kernel writes flat contiguous offsets; compute into a scratch
        # buffer and copy back so a strided `out` still gets correct values.
        tmp = torch.empty_like(x_c)
        _hardshrink_launch(x_c, lambd, tmp)
        out.copy_(tmp)
    return out